diff --git a/Cargo.lock b/Cargo.lock index 75f67709..1a0a5eed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,28 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "async-trait" version = "0.1.83" @@ -4003,6 +4025,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-test", "tracing", "tracing-subscriber", ] @@ -5122,6 +5145,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.12" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index ea910406..ed666fe4 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -29,12 +29,13 @@ rig-derive = { path = "./rig-core-derive", optional = true } anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" +tokio-test = "0.4.4" [features] derive = ["dep:rig-derive"] [[test]] -name = "embeddable_macro" +name = "extract_embedding_fields_macro" required-features = ["derive"] [[example]] @@ -47,4 +48,4 @@ required-features = ["derive"] [[example]] name = "vector_search_cohere" -required-features = ["derive"] \ No newline at end of file +required-features = ["derive"] diff --git 
a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 3f9f0b1b..576491d5 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -252,12 +252,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ab1387a1..9829089f 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -5,13 +5,14 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - Embeddable, + ExtractEmbeddingFields, }; use serde::Serialize; // Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] +// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition` +// and tag that field with `#[embed]`. +#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] @@ -26,6 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the definitions of all the documents using the specified embedding model. 
let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ FakeDefinition { @@ -54,14 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index bdad5109..c140da15 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -161,12 +161,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 5aebe12d..4777b42c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
-#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -57,14 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(model); let results = index diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 54adc598..6d966004 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -58,14 +58,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? 
.index(search_model); let results = index diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 86bb13ad..39b72018 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator syn::Resu let (custom_targets, custom_target_size) = data_struct.custom()?; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. + // ie. do not implement `ExtractEmbeddingFields` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( name, @@ -34,7 +34,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu _ => { return Err(syn::Error::new_spanned( input, - "Embeddable derive macro should only be used on structs", + "ExtractEmbeddingFields derive macro should only be used on structs", )) } }; @@ -42,18 +42,18 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: Embeddable trait is imported with the macro. + // Note: `ExtractEmbeddingFields` trait is imported with the macro. - impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Error = rig::embeddings::embeddable::EmbeddableError; + impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { + type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { + fn extract_embedding_fields(&self) -> Result, Self::Error> { #target_stream; rig::OneOrMany::merge( embed_targets.into_iter() .collect::, _>>()? 
- ).map_err(rig::embeddings::embeddable::EmbeddableError::new) + ).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new) } } }; @@ -87,7 +87,7 @@ impl StructParser for DataStruct { if !embed_targets.is_empty() { ( quote! { - vec![#(#embed_targets.embeddable()),*] + vec![#(#embed_targets.extract_embedding_fields()),*] }, embed_targets.len(), ) diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index d28a0d78..8ad69a65 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -4,18 +4,18 @@ use syn::{parse_macro_input, DeriveInput}; mod basic; mod custom; -mod embeddable; +mod extract_embedding_fields; pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embeddable, attributes(embed))] +#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - embeddable::expand_derive_embedding(&mut input) + extract_embedding_fields::expand_derive_embedding(&mut input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 887dbf0e..b1b310bb 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,22 +1,22 @@ //! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! Only types that implement the [ExtractEmbeddingFields] trait can be added to the [EmbeddingsBuilder]. 
use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use crate::{ - embeddings::{Embeddable, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{Embedding, EmbeddingError, EmbeddingModel, ExtractEmbeddingFields}, OneOrMany, }; /// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { model: M, documents: Vec<(T, OneOrMany)>, } -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,18 +25,18 @@ impl EmbeddingsBuilder { } } - /// Add a document that implements `Embeddable` to the builder. + /// Add a document that implements `ExtractEmbeddingFields` to the builder. pub fn document(mut self, document: T) -> Result { - let embed_targets = document.embeddable()?; + let embed_targets = document.extract_embedding_fields()?; self.documents.push((document, embed_targets)); Ok(self) } - /// Add many documents that implement `Embeddable` to the builder. + /// Add many documents that implement `ExtractEmbeddingFields` to the builder. pub fn documents(mut self, documents: Vec) -> Result { for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; + let embed_targets = doc.extract_embedding_fields()?; self.documents.push((doc, embed_targets)); } @@ -53,13 +53,13 @@ impl EmbeddingsBuilder { /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// Embeddable, +/// ExtractEmbeddingFields, /// }; /// use serde::{Deserialize, Serialize}; /// /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. 
-/// #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] /// struct FakeDefinition { /// id: String, /// word: String, @@ -103,7 +103,7 @@ impl EmbeddingsBuilder { /// .build() /// .await?; /// ``` -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). pub async fn build(self) -> Result)>, EmbeddingError> { @@ -171,8 +171,10 @@ impl EmbeddingsBuilder, } - impl Embeddable for FakeDefinition { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinition { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result, Self::Error> { + crate::OneOrMany::many(self.definitions.clone()) + .map_err(ExtractEmbeddingFieldsError::new) } } @@ -253,10 +256,10 @@ mod tests { definition: String, } - impl Embeddable for FakeDefinitionSingle { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinitionSingle { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { + fn extract_embedding_fields(&self) -> Result, Self::Error> { Ok(crate::OneOrMany::one(self.definition.clone())) } } diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs deleted file mode 100644 index f5a69fd6..00000000 --- a/rig-core/src/embeddings/embeddable.rs +++ /dev/null @@ -1,160 +0,0 @@ -//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! # Example -//! ```rust -//! use std::env; -//! -//! use serde::{Deserialize, Serialize}; -//! use rig::OneOrMany; -//! 
-//! struct FakeDefinition { -//! id: String, -//! word: String, -//! definition: String, -//! } -//! -//! let fake_definition = FakeDefinition { -//! id: "doc1".to_string(), -//! word: "hello".to_string(), -//! definition: "used as a greeting or to begin a conversation".to_string() -//! }; -//! -//! impl Embeddable for FakeDefinition { -//! type Error = anyhow::Error; -//! -//! fn embeddable(&self) -> Result, Self::Error> { -//! // Embeddigns only need to be generated for `definition` field. -//! // Select it from the struct and return it as a single item. -//! Ok(OneOrMany::one(self.definition.clone())) -//! } -//! } -//! ``` - -use crate::one_or_many::OneOrMany; - -/// Error type used for when the `embeddable` method fails. -/// Used by default implementations of `Embeddable` for common types. -#[derive(Debug, thiserror::Error)] -#[error("{0}")] -pub struct EmbeddableError(#[from] Box); - -impl EmbeddableError { - pub fn new(error: E) -> Self { - EmbeddableError(Box::new(error)) - } -} - -/// Trait for types that can be embedded. -/// The `embeddable` method returns a `OneOrMany` which contains strings for which embeddings will be generated by the embeddings builder. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
-pub trait Embeddable { - type Error: std::error::Error + Sync + Send + 'static; - - fn embeddable(&self) -> Result, Self::Error>; -} - -// ================================================================ -// Implementations of Embeddable for common types -// ================================================================ -impl Embeddable for String { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.clone())) - } -} - -impl Embeddable for i8 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i16 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i128 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for bool { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for char { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for serde_json::Value { - type Error = EmbeddableError; - - fn 
embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one( - serde_json::to_string(self).map_err(EmbeddableError::new)?, - )) - } -} - -impl Embeddable for Vec { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - let items = self - .iter() - .map(|item| item.embeddable()) - .collect::, _>>() - .map_err(EmbeddableError::new)?; - - OneOrMany::merge(items).map_err(EmbeddableError::new) - } -} diff --git a/rig-core/src/embeddings/extract_embedding_fields.rs b/rig-core/src/embeddings/extract_embedding_fields.rs new file mode 100644 index 00000000..e62d43c6 --- /dev/null +++ b/rig-core/src/embeddings/extract_embedding_fields.rs @@ -0,0 +1,163 @@ +//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. + +use crate::one_or_many::OneOrMany; + +/// Error type used for when the `extract_embedding_fields` method fails. +/// Used by default implementations of `ExtractEmbeddingFields` for common types. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct ExtractEmbeddingFieldsError(#[from] Box); + +impl ExtractEmbeddingFieldsError { + pub fn new(error: E) -> Self { + ExtractEmbeddingFieldsError(Box::new(error)) + } +} + +/// Derive this trait for structs whose fields need to be converted to vector embeddings. +/// The `extract_embedding_fields` method returns a `OneOrMany`. This function extracts the fields that need to be embedded and returns them as a list of strings. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
+/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl ExtractEmbeddingFields for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn extract_embedding_fields(&self) -> Result, Self::Error> { +/// // Embeddings only need to be generated for the `definitions` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. +/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` +pub trait ExtractEmbeddingFields { + type Error: std::error::Error + Sync + Send + 'static; + + fn extract_embedding_fields(&self) -> Result, Self::Error>; +} + +// ================================================================ +// Implementations of ExtractEmbeddingFields for common types +// ================================================================ +impl ExtractEmbeddingFields for String { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.clone())) + } +} + +impl ExtractEmbeddingFields for i8 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i16 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + 
Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i128 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for bool { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for char { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for serde_json::Value { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one( + serde_json::to_string(self).map_err(ExtractEmbeddingFieldsError::new)?, + )) + } +} + +impl ExtractEmbeddingFields for Vec { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + let items = self + .iter() + .map(|item| item.extract_embedding_fields()) + .collect::, _>>() + .map_err(ExtractEmbeddingFieldsError::new)?; + + 
OneOrMany::merge(items).map_err(ExtractEmbeddingFieldsError::new) + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 763e0f30..37323cf5 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,10 +4,11 @@ //! and document similarity. pub mod builder; -pub mod embeddable; pub mod embedding; +pub mod extract_embedding_fields; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embeddable::Embeddable; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use extract_embedding_fields::ExtractEmbeddingFields; +pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 139b11b8..7550038f 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,7 +1,7 @@ -use crate::{tool::ToolEmbeddingDyn, Embeddable, OneOrMany}; +use crate::{tool::ToolEmbeddingDyn, ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -use super::embeddable::EmbeddableError; +use super::extract_embedding_fields::ExtractEmbeddingFieldsError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. #[derive(Clone, Serialize, Default, Eq, PartialEq)] @@ -11,20 +11,80 @@ pub struct EmbeddableTool { pub embedding_docs: Vec, } -impl Embeddable for EmbeddableTool { - type Error = EmbeddableError; +impl ExtractEmbeddingFields for EmbeddableTool { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { - OneOrMany::many(self.embedding_docs.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result, Self::Error> { + OneOrMany::many(self.embedding_docs.clone()).map_err(ExtractEmbeddingFieldsError::new) } } impl EmbeddableTool { - /// Convert item that implements ToolEmbedding to an EmbeddableTool. - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { + /// Convert item that implements ToolEmbeddingDyn to an EmbeddableTool. 
+ /// # Example + /// ```rust + /// use rig::{ + /// completion::ToolDefinition, + /// embeddings::EmbeddableTool, + /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, + /// }; + /// use serde_json::json; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Math error")] + /// struct NothingError; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Init error")] + /// struct InitError; + /// + /// struct Nothing; + /// impl Tool for Nothing { + /// const NAME: &'static str = "nothing"; + /// + /// type Error = NothingError; + /// type Args = (); + /// type Output = (); + /// + /// async fn definition(&self, _prompt: String) -> ToolDefinition { + /// serde_json::from_value(json!({ + /// "name": "nothing", + /// "description": "nothing", + /// "parameters": {} + /// })) + /// .expect("Tool Definition") + /// } + /// + /// async fn call(&self, args: Self::Args) -> Result { + /// Ok(()) + /// } + /// } + /// + /// impl ToolEmbedding for Nothing { + /// type InitError = InitError; + /// type Context = (); + /// type State = (); + /// + /// fn init(_state: Self::State, _context: Self::Context) -> Result { + /// Ok(Nothing) + /// } + /// + /// fn embedding_docs(&self) -> Vec { + /// vec!["Do nothing.".into()] + /// } + /// + /// fn context(&self) -> Self::Context {} + /// } + /// + /// let tool = EmbeddableTool::try_from(&Nothing).unwrap(); + /// + /// assert_eq!(tool.name, "nothing".to_string()); + /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); + /// ``` + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { Ok(EmbeddableTool { name: tool.name(), - context: tool.context().map_err(EmbeddableError::new)?, + context: tool.context().map_err(ExtractEmbeddingFieldsError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 07c59f96..5383b34e 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,8 +78,8 @@ pub mod tool; pub mod vector_store; // Re-export commonly 
used types and traits -pub use embeddings::embeddable::Embeddable; -pub use one_or_many::OneOrMany; +pub use embeddings::ExtractEmbeddingFields; +pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] -pub use rig_derive::Embeddable; +pub use rig_derive::ExtractEmbeddingFields; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 8f8eefd4..a6d8f00b 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; @@ -85,7 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings( + pub fn embeddings( &self, model: &str, input_type: &str, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index b20df22f..789c5282 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e92896b8..528faba5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, 
+ embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, tool::EmbeddableTool}, }; #[derive(Debug, thiserror::Error)] @@ -330,7 +330,7 @@ impl ToolSet { /// Convert tools in self to objects of type EmbeddableTool. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn embedabble_tools(&self) -> Result, EmbeddableError> { + pub fn embedabble_tools(&self) -> Result, ExtractEmbeddingFieldsError> { self.tools .values() .filter_map(|tool_type| { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 208b5f13..f4f067fe 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,6 +76,21 @@ impl InMemoryVectorStore { Ok(self) } + /// Add documents to the store. Define a function that takes as input the reference of the document and returns its id. + /// Returns the store with the added documents. + pub fn add_documents_with_id( + mut self, + documents: Vec<(D, OneOrMany)>, + id_f: fn(&D) -> String, + ) -> Result { + for (doc, embeddings) in documents { + let id = id_f(&doc); + self.embeddings.insert(id, (doc, embeddings)); + } + + Ok(self) + } + /// Get the document by its id and deserialize it into the given type. 
pub fn get_document Deserialize<'a>>( &self, diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 38e45d0e..044d8c2a 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,6 +17,9 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box), + + #[error("Missing Id: {0}")] + MissingIdError(String), } /// Trait for vector store indexes diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/extract_embedding_fields_macro.rs similarity index 83% rename from rig-core/tests/embeddable_macro.rs rename to rig-core/tests/extract_embedding_fields_macro.rs index cbc76c80..5f8891ff 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/extract_embedding_fields_macro.rs @@ -1,14 +1,14 @@ -use rig::embeddings::embeddable::EmbeddableError; -use rig::{Embeddable, OneOrMany}; +use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; +use rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -fn serialize(definition: Definition) -> Result, EmbeddableError> { +fn serialize(definition: Definition) -> Result, ExtractEmbeddingFieldsError> { Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(EmbeddableError::new)?, + serde_json::to_string(&definition).map_err(ExtractEmbeddingFieldsError::new)?, )) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition { id: String, word: String, @@ -41,7 +41,7 @@ fn test_custom_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one( "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -49,7 +49,7 @@ fn test_custom_embed() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition2 { id: String, #[embed] @@ -76,17 +76,17 @@ fn 
test_custom_and_basic_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap().first(), + fake_definition.extract_embedding_fields().unwrap().first(), "house".to_string() ); assert_eq!( - fake_definition.embeddable().unwrap().rest(), + fake_definition.extract_embedding_fields().unwrap().rest(), vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition3 { id: String, word: String, @@ -109,12 +109,12 @@ fn test_single_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one(definition) ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company { id: String, company: String, @@ -132,7 +132,7 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let result = company.embeddable().unwrap(); + let result = company.extract_embedding_fields().unwrap(); assert_eq!( result, @@ -153,7 +153,7 @@ fn test_multiple_embed_strings() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company2 { id: String, #[embed] @@ -173,7 +173,7 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); assert_eq!( - company.embeddable().unwrap(), + company.extract_embedding_fields().unwrap(), OneOrMany::many(vec![ "Google".to_string(), "25".to_string(), diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index d6e02a5a..1a9089a5 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::Embedding; -use rig::{Embeddable, OneOrMany}; +use 
rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Deserialize; -#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] pub struct FakeDefinition { pub id: String, #[embed] diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..7ffd6b12 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,12 +3,12 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -40,12 +40,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -61,10 +62,10 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. 
let search_params = SearchParams::default(); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index - let results = vector_store + let results = vector_store_index .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..859442be 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,11 +3,11 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -33,17 +33,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Create table with embeddings. 
- let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..824deda0 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,11 +4,11 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -46,12 +46,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -73,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. 
let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 18d87d7f..2eea2357 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -20,79 +20,17 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } +/// Type on which vector searches can be performed for a lanceDb table. /// # Example /// ``` -/// use std::{env, sync::Arc}; - -/// use arrow_array::RecordBatchIterator; -/// use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; -/// use lancedb::index::vector::IvfPqIndexBuilder; -/// use rig::vector_store::VectorStoreIndex; -/// use rig::{ -/// embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, -/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// }; -/// use rig_lancedb::{LanceDbVectorStore, SearchParams}; -/// -/// #[path = "../examples/fixtures/lib.rs"] -/// mod fixture; -/// -/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); -/// -/// // Select an embedding model. -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -/// -/// // Initialize LanceDB locally. -/// let db = lancedb::connect("data/lancedb-store").execute().await?; -/// -/// // Generate embeddings for the test data. -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .documents(fake_definitions())? -/// // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
-/// .documents( -/// (0..256) -/// .map(|i| FakeDefinition { -/// id: format!("doc{}", i), -/// definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() -/// }) -/// .collect(), -/// )? -/// .build() -/// .await?; +/// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; +/// use rig::embeddings::EmbeddingModel; /// -/// // Create table with embeddings. -/// let record_batch = as_record_batch(embeddings, model.ndims()); -/// let table = db -/// .create_table( -/// "definitions", -/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), -/// ) -/// .execute() -/// .await?; -/// -/// // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information -/// table -/// .create_index( -/// &["embedding"], -/// lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), -/// ) -/// .execute() -/// .await?; -/// -/// // Define search_params params that will be used by the vector store to perform the vector search. -/// let search_params = SearchParams::default(); -/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; -/// -/// // Query the index -/// let results = vector_store -/// .top_n::("My boss says I zindle too much, what does that mean?", 1) -/// .await?; -/// -/// println!("Results: {:?}", results); +/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. +/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. +/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` -pub struct LanceDbVectorStore { +pub struct LanceDbVectorIndex { /// Defines which model is used to generate embeddings for the vector store.
model: M, /// LanceDB table containing embeddings. @@ -103,7 +41,24 @@ pub struct LanceDbVectorStore { search_params: SearchParams, } -impl LanceDbVectorStore { +impl LanceDbVectorIndex { + /// Create an instance of `LanceDbVectorIndex` with an existing table and model. + /// Define the id field name of the table. + /// Define search parameters that will be used to perform vector searches on the table. + pub async fn new( + table: lancedb::Table, + model: M, + id_field: &str, + search_params: SearchParams, + ) -> Result { + Ok(Self { + table, + model, + id_field: id_field.to_string(), + search_params, + }) + } + /// Apply the search_params to the vector query. /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait. fn build_query(&self, mut query: VectorQuery) -> VectorQuery { @@ -155,6 +110,10 @@ pub enum SearchType { } /// Parameters used to perform a vector search on a LanceDb table. +/// # Example +/// ``` +/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine); +/// ``` #[derive(Debug, Clone, Default)] pub struct SearchParams { distance_type: Option, @@ -215,26 +174,22 @@ impl SearchParams { } } -impl LanceDbVectorStore { - /// Create an instance of `LanceDbVectorStore` with an existing table and model. - /// Define the id field name of the table. - /// Define search parameters that will be used to perform vector searches on the table. - pub async fn new( - table: lancedb::Table, - model: M, - id_field: &str, - search_params: SearchParams, - ) -> Result { - Ok(Self { - table, - model, - id_field: id_field.to_string(), - search_params, - }) - } -} - -impl VectorStoreIndex for LanceDbVectorStore { +impl VectorStoreIndex for LanceDbVectorIndex { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. 
+ /// # Example + /// ``` + /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, @@ -269,6 +224,18 @@ impl VectorStoreIndex for LanceDbVectorStore .collect() } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. 
+ /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n_ids( &self, query: &str, diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index b095c060..cc16d4bc 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,7 +3,7 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::Embeddable; +use rig::ExtractEmbeddingFields; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; @@ -11,7 +11,7 @@ use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
-#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] struct FakeDefinition { #[serde(rename = "_id")] id: String, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 3201c648..655f3939 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,134 +7,35 @@ use rig::{ }; use serde::Deserialize; +fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + /// # Example /// ``` -/// use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -/// use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -/// use serde::{Deserialize, Serialize}; -/// use std::env; - -/// use rig::Embeddable; -/// use rig::{ -/// embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, -/// }; /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; - -/// // Shape of data that needs to be RAG'ed. -/// // The definition field will be used to generate embeddings. -/// #[derive(Embeddable, Clone, Deserialize, Debug)] -/// struct FakeDefinition { -/// #[serde(rename = "_id")] -/// id: String, -/// #[embed] -/// definition: String, -/// } - -/// #[derive(Clone, Deserialize, Debug, Serialize)] -/// struct Link { -/// word: String, -/// link: String, -/// } - -/// // Shape of the document to be stored in MongoDB, with embeddings. 
-/// #[derive(Serialize, Debug)] +/// use rig::embeddings::EmbeddingModel; +/// +/// #[derive(serde::Serialize, Debug)] /// struct Document { /// #[serde(rename = "_id")] /// id: String, /// definition: String, /// embedding: Vec, /// } -/// // Initialize OpenAI client -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); - -/// // Initialize MongoDB client -/// let mongodb_connection_string = -/// env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); -/// let options = ClientOptions::parse(mongodb_connection_string) -/// .await -/// .expect("MongoDB connection string should be valid"); - -/// let mongodb_client = -/// MongoClient::with_options(options).expect("MongoDB client options should be valid"); - -/// // Initialize MongoDB vector store -/// let collection: Collection = mongodb_client -/// .database("knowledgebase") -/// .collection("context"); - -/// // Select the embedding model and generate our embeddings -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - -/// let fake_definitions = vec![ -/// FakeDefinition { -/// id: "doc0".to_string(), -/// definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc1".to_string(), -/// definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc2".to_string(), -/// definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), -/// } -/// ]; - -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .documents(fake_definitions)? -/// .build() -/// .await?; - -/// let mongo_documents = embeddings -/// .iter() -/// .map( -/// |(FakeDefinition { id, definition, .. 
}, embedding)| Document { -/// id: id.clone(), -/// definition: definition.clone(), -/// embedding: embedding.first().vec.clone(), -/// }, -/// ) -/// .collect::>(); - -/// match collection.insert_many(mongo_documents, None).await { -/// Ok(_) => println!("Documents added successfully"), -/// Err(e) => println!("Error adding documents: {:?}", e), -/// }; - -/// // Create a vector index on our vector store. -/// // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. -/// // IMPORTANT: Reuse the same model that was used to generate the embeddings +/// +/// let collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. +/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// let index = MongoDbVectorStore::new(collection).index( /// model, -/// "vector_index", -/// SearchParams::new("embedding"), +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. /// ); - -/// // Query the index -/// let results = index -/// .top_n::("What is a linglingdong?", 1) -/// .await?; - -/// println!("Results: {:?}", results); - -/// let id_results = index -/// .top_n_ids("What is a linglingdong?", 1) -/// .await? -/// .into_iter() -/// .map(|(score, id)| (score, id)) -/// .collect::>(); - -/// println!("ID results: {:?}", id_results); /// ``` pub struct MongoDbVectorStore { collection: mongodb::Collection, } -fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. 
pub fn new(collection: mongodb::Collection) -> Self { @@ -263,6 +164,41 @@ impl SearchParams { impl VectorStoreIndex for MongoDbVectorIndex { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec<f64>, + /// } + /// + /// #[derive(serde::Deserialize, Debug)] + /// struct Definition { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// } + /// + /// let collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. + /// + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, @@ -303,6 +239,33 @@ impl VectorStoreIndex Ok(results) } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. 
+ /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec<f64>, + /// } + /// + /// let collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n_ids( &self, query: &str,