From a781fc99c22e7d682c59db4103577c17c67d9a2b Mon Sep 17 00:00:00 2001 From: Christophe Date: Wed, 6 Nov 2024 14:21:00 -0500 Subject: [PATCH 1/9] refactor: Big refactor --- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 7 +- rig-core/examples/test.rs | 76 ++++++++ rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/completion.rs | 2 +- rig-core/src/embeddings/builder.rs | 138 +++++++-------- rig-core/src/embeddings/embed.rs | 163 ++++++++++++++++++ rig-core/src/embeddings/embedding.rs | 60 +++++-- .../embeddings/extract_embedding_fields.rs | 163 ------------------ rig-core/src/embeddings/mod.rs | 4 +- rig-core/src/embeddings/tool.rs | 19 +- rig-core/src/lib.rs | 2 +- rig-core/src/providers/cohere.rs | 19 +- rig-core/src/providers/openai.rs | 9 +- rig-core/src/tool.rs | 4 +- rig-core/src/vector_store/in_memory_store.rs | 4 +- rig-core/src/vector_store/mod.rs | 2 +- .../tests/extract_embedding_fields_macro.rs | 4 +- rig-lancedb/examples/fixtures/lib.rs | 2 +- rig-lancedb/src/lib.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- rig-mongodb/src/lib.rs | 4 +- 23 files changed, 399 insertions(+), 295 deletions(-) create mode 100644 rig-core/examples/test.rs create mode 100644 rig-core/src/embeddings/embed.rs delete mode 100644 rig-core/src/embeddings/extract_embedding_fields.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 9829089f..3877e651 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -5,7 +5,7 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - ExtractEmbeddingFields, + Embed, }; use serde::Serialize; diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index c140da15..c72276c7 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::EmbeddingsBuilder, + embeddings::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -155,9 +155,8 @@ async fn main() -> Result<(), anyhow::Error> { .dynamic_tool(Subtract) .build(); - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? - .build() + let embeddings = embedding_model + .embed_many(toolset.embedabble_tools()?) .await?; let index = InMemoryVectorStore::default() diff --git a/rig-core/examples/test.rs b/rig-core/examples/test.rs new file mode 100644 index 00000000..7eb30779 --- /dev/null +++ b/rig-core/examples/test.rs @@ -0,0 +1,76 @@ +use std::env; + +use rig::{ + embeddings::{embed::EmbedError, Embedding, EmbeddingError, EmbeddingModel, EmbeddingsBuilder}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embed, OneOrMany, +}; +use serde::{Deserialize, Serialize}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + word: String, + // #[embed] + definitions: Vec, +} + +impl Embed for FakeDefinition { + fn embed(&self, embedder: &mut rig::embeddings::Embedder) -> Result<(), EmbedError> { + for doc in &self.definitions { + embedder.embed(doc.clone()); + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create OpenAI client + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + + let docs = vec![ + FakeDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), + "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ]; + + let embeddings = model.embed_many(docs).await?; + + let data = vec![ + "What is a flurbo?", + "What is a glarb-glarb?", + "What is a linglingdong?", + ]; + + let embeddings = model.embed_many(data).await?; + + Ok(()) +} diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 4777b42c..089822d3 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -4,7 +4,7 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - ExtractEmbeddingFields, + Embed, }; use serde::{Deserialize, Serialize}; diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6d966004..bebe87e4 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -4,7 +4,7 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - ExtractEmbeddingFields, + Embed, }; use serde::{Deserialize, Serialize}; diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index e766fb27..f13f316b 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -82,7 +82,7 @@ pub enum CompletionError { /// Error building the completion request #[error("RequestError: {0}")] - RequestError(#[from] Box), + RequestError(#[from] Box), /// Error parsing the completion response #[error("ResponseError: {0}")] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 68d9e78b..cb85d49e 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -3,20 +3,22 @@ use std::{cmp::max, collections::HashMap}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, StreamExt}; use crate::{ - embeddings::{Embedding, EmbeddingError, EmbeddingModel, ExtractEmbeddingFields}, + embeddings::{Embed, Embedding, EmbeddingError, EmbeddingModel}, OneOrMany, }; +use super::{embed::EmbedError, Embedder}; + /// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { model: M, documents: Vec<(T, OneOrMany)>, } -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -26,22 +28,23 @@ impl EmbeddingsBuilder { } /// Add a document that implements `ExtractEmbeddingFields` to the builder. - pub fn document(mut self, document: T) -> Result { - let embed_targets = document.extract_embedding_fields()?; + pub fn document(mut self, document: T) -> Result { + let mut embedder = Embedder::default(); + document.embed(&mut embedder)?; + + self.documents + .push((document, OneOrMany::many(embedder.texts).unwrap())); - self.documents.push((document, embed_targets)); Ok(self) } /// Add many documents that implement `ExtractEmbeddingFields` to the builder. - pub fn documents(mut self, documents: Vec) -> Result { - for doc in documents.into_iter() { - let embed_targets = doc.extract_embedding_fields()?; - - self.documents.push((doc, embed_targets)); - } + pub fn documents(self, documents: impl IntoIterator) -> Result { + let builder = documents + .into_iter() + .try_fold(self, |builder, doc| builder.document(doc))?; - Ok(self) + Ok(builder) } } @@ -103,45 +106,38 @@ impl EmbeddingsBuilder { /// .build() /// .await?; /// ``` -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). pub async fn build(self) -> Result)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. - let documents_map = self - .documents - .clone() - .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) - .collect::>(); - - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter( - embed_targets - .clone() - .into_iter() - .map(move |target| (i, target)), - ) - }) - // Chunk them into N (the emebdding API limit per request). + use stream::TryStreamExt; + + let mut docs = HashMap::new(); + let mut texts = HashMap::new(); + + // Gather the texts to embed for each document + for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() { + docs.insert(i, doc); + texts.insert(i, doc_texts); + } + + // Compute the embeddings + let mut embeddings = stream::iter(texts.into_iter()) + // Merge the texts of each document into a single list of texts. + .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text)))) + // Chunk them into batches the embedding API limit per request. .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. + // Generate the embeddings for each batch. .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - - Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::>(), - ) + let embeddings = self + .model + .embed_texts(docs.into_iter().map(|(_, text)| text)) + .await?; + Ok::<_, EmbeddingError>(embeddings.into_iter().enumerate().collect::>()) }) - .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + // Collect the embeddings into a HashMap. .try_fold( HashMap::new(), |mut acc: HashMap<_, OneOrMany>, embeddings| async move { @@ -154,27 +150,26 @@ impl Embeddi Ok(acc) }, ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); - - Ok(embeddings) + .await?; + + // Merge the embeddings with their respective documents + Ok(docs + .into_iter() + .map(|(i, doc)| { + ( + doc, + embeddings.remove(&i).expect("Document should be present"), + ) + }) + .collect()) } } #[cfg(test)] mod tests { use crate::{ - embeddings::{ - extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel, - }, - ExtractEmbeddingFields, + embeddings::{embed::EmbedError, Embedder, Embedding, EmbeddingModel}, + Embed, }; use super::EmbeddingsBuilder; @@ -189,7 +184,7 @@ mod tests { 10 } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator + Send, ) -> Result, crate::embeddings::EmbeddingError> { @@ -209,12 +204,12 @@ mod tests { definitions: Vec, } - impl ExtractEmbeddingFields for FakeDefinition { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()) - .map_err(ExtractEmbeddingFieldsError::new) + impl Embed for FakeDefinition { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + for definition in &self.definitions { + embedder.embed(definition.clone()); + } + Ok(()) } } @@ -256,11 +251,10 @@ mod tests { definition: String, } - impl ExtractEmbeddingFields for FakeDefinitionSingle { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(crate::OneOrMany::one(self.definition.clone())) + impl Embed for FakeDefinitionSingle { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.definition.clone()); + Ok(()) } } diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs new file mode 100644 index 00000000..cda02d33 --- /dev/null +++ b/rig-core/src/embeddings/embed.rs @@ -0,0 +1,163 @@ +//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. + +// use crate::one_or_many::OneOrMany; + +use super::EmbeddingModel; + +/// Error type used for when the `extract_embedding_fields` method fails. +/// Used by default implementations of `ExtractEmbeddingFields` for common types. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct EmbedError(#[from] Box); + +impl EmbedError { + pub fn new(error: E) -> Self { + EmbedError(Box::new(error)) + } +} + +/// Derive this trait for structs whose fields need to be converted to vector embeddings. +/// The `extract_embedding_fields` method returns a `OneOrMany`. This function extracts the fields that need to be embedded and returns them as a list of strings. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl ExtractEmbeddingFields for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn extract_embedding_fields(&self) -> Result, Self::Error> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. +/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` +pub trait Embed { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError>; +} + +#[derive(Default)] +pub struct Embedder { + pub texts: Vec, +} + +impl Embedder { + pub fn embed(&mut self, text: String) { + self.texts.push(text); + } +} + +// ================================================================ +// Implementations of ExtractEmbeddingFields for common types +// ================================================================ +impl Embed for String { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.clone()); + Ok(()) + } +} + +impl Embed for &str { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i8 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i16 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i32 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i64 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i128 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f32 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f64 { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for bool { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for char { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for serde_json::Value { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?); + Ok(()) + } +} + +impl Embed for Vec { + fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + for item in self { + item.embed(embedder).map_err(EmbedError::new)?; + } + Ok(()) + } +} diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 47820a81..8748c6f1 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -10,6 +10,10 @@ use serde::{Deserialize, Serialize}; +use crate::OneOrMany; + +use super::{Embed, EmbeddingsBuilder}; + #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) @@ -22,7 +26,7 @@ pub enum EmbeddingError { /// Error processing the document for embedding #[error("DocumentError: {0}")] - DocumentError(String), + DocumentError(Box), /// Error parsing the completion response #[error("ResponseError: {0}")] @@ -41,29 +45,59 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// The number of dimensions in the embedding vector. fn ndims(&self) -> usize; - /// Embed a single document - fn embed_document( + /// Embed multiple text documents in a single request + fn embed_texts( + &self, + documents: impl IntoIterator + Send, + ) -> impl std::future::Future, EmbeddingError>> + Send; + + /// Embed a single text document + fn embed_text( &self, document: &str, - ) -> impl std::future::Future> + Send - where - Self: Sync, + ) -> impl std::future::Future> + Send { + async { + Ok(self + .embed_texts(vec![document.to_string()]) + .await? + .pop() + .expect("There should be at least one embedding")) + } + } + + /// Embed a single document + fn embed( + &self, + document: T, + ) -> impl std::future::Future, EmbeddingError>> + Send { async { Ok(self - .embed_documents(vec![document.to_string()]) + .embed_many(vec![document]) .await? - .first() - .cloned() - .expect("One embedding should be present")) + .pop() + .map(|(_, embedding)| embedding) + .expect("There should be at least one embedding")) } } /// Embed multiple documents in a single request - fn embed_documents( + fn embed_many + Send>( &self, - documents: impl IntoIterator + Send, - ) -> impl std::future::Future, EmbeddingError>> + Send; + documents: I, + ) -> impl std::future::Future)>, EmbeddingError>> + Send + where + ::IntoIter: std::marker::Send, + { + async { + let builder = EmbeddingsBuilder::new(self.clone()); + builder + .documents(documents) + .map_err(|err| EmbeddingError::DocumentError(Box::new(err)))? + .build() + .await + } + } } /// Struct that holds a single document and its embedding. diff --git a/rig-core/src/embeddings/extract_embedding_fields.rs b/rig-core/src/embeddings/extract_embedding_fields.rs deleted file mode 100644 index e62d43c6..00000000 --- a/rig-core/src/embeddings/extract_embedding_fields.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. - -use crate::one_or_many::OneOrMany; - -/// Error type used for when the `extract_embedding_fields` method fails. -/// Used by default implementations of `ExtractEmbeddingFields` for common types. -#[derive(Debug, thiserror::Error)] -#[error("{0}")] -pub struct ExtractEmbeddingFieldsError(#[from] Box); - -impl ExtractEmbeddingFieldsError { - pub fn new(error: E) -> Self { - ExtractEmbeddingFieldsError(Box::new(error)) - } -} - -/// Derive this trait for structs whose fields need to be converted to vector embeddings. -/// The `extract_embedding_fields` method returns a `OneOrMany`. This function extracts the fields that need to be embedded and returns them as a list of strings. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. -/// # Example -/// ```rust -/// use std::env; -/// -/// use serde::{Deserialize, Serialize}; -/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; -/// -/// struct FakeDefinition { -/// id: String, -/// word: String, -/// definitions: String, -/// } -/// -/// let fake_definition = FakeDefinition { -/// id: "doc1".to_string(), -/// word: "rock".to_string(), -/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() -/// }; -/// -/// impl ExtractEmbeddingFields for FakeDefinition { -/// type Error = EmptyListError; -/// -/// fn extract_embedding_fields(&self) -> Result, Self::Error> { -/// // Embeddings only need to be generated for `definition` field. -/// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. -/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); -/// -/// OneOrMany::many(definitions) -/// } -/// } -/// ``` -pub trait ExtractEmbeddingFields { - type Error: std::error::Error + Sync + Send + 'static; - - fn extract_embedding_fields(&self) -> Result, Self::Error>; -} - -// ================================================================ -// Implementations of ExtractEmbeddingFields for common types -// ================================================================ -impl ExtractEmbeddingFields for String { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.clone())) - } -} - -impl ExtractEmbeddingFields for i8 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i16 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i32 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i64 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i128 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for f32 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for f64 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for bool { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for char { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for serde_json::Value { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - Ok(OneOrMany::one( - serde_json::to_string(self).map_err(ExtractEmbeddingFieldsError::new)?, - )) - } -} - -impl ExtractEmbeddingFields for Vec { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - let items = self - .iter() - .map(|item| item.extract_embedding_fields()) - .collect::, _>>() - .map_err(ExtractEmbeddingFieldsError::new)?; - - OneOrMany::merge(items).map_err(ExtractEmbeddingFieldsError::new) - } -} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37323cf5..f008c050 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. pub mod builder; +pub mod embed; pub mod embedding; -pub mod extract_embedding_fields; pub mod tool; pub use builder::EmbeddingsBuilder; +pub use embed::{Embed, Embedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; -pub use extract_embedding_fields::ExtractEmbeddingFields; pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 7550038f..1a2e78a6 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,7 +1,7 @@ -use crate::{tool::ToolEmbeddingDyn, ExtractEmbeddingFields, OneOrMany}; +use crate::{tool::ToolEmbeddingDyn, Embed}; use serde::Serialize; -use super::extract_embedding_fields::ExtractEmbeddingFieldsError; +use super::embed::EmbedError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. #[derive(Clone, Serialize, Default, Eq, PartialEq)] @@ -11,11 +11,12 @@ pub struct EmbeddableTool { pub embedding_docs: Vec, } -impl ExtractEmbeddingFields for EmbeddableTool { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { - OneOrMany::many(self.embedding_docs.clone()).map_err(ExtractEmbeddingFieldsError::new) +impl Embed for EmbeddableTool { + fn embed(&self, embedder: &mut super::embed::Embedder) -> Result<(), EmbedError> { + for doc in &self.embedding_docs { + embedder.embed(doc.clone()); + } + Ok(()) } } @@ -81,10 +82,10 @@ impl EmbeddableTool { /// assert_eq!(tool.name, "nothing".to_string()); /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); /// ``` - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { Ok(EmbeddableTool { name: tool.name(), - context: tool.context().map_err(ExtractEmbeddingFieldsError::new)?, + context: tool.context().map_err(EmbedError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index edc5cab2..c3f4fcf8 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -80,7 +80,7 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::ExtractEmbeddingFields; +pub use embeddings::Embed; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ef449dd6..4299684e 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, ExtractEmbeddingFields, + json_utils, Embed, }; use schemars::JsonSchema; @@ -92,7 +92,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings( + pub fn embeddings( &self, model: &str, input_type: &str, @@ -201,7 +201,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { @@ -222,11 +222,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { match response.json::>().await? { ApiResponse::Ok(response) => { if response.embeddings.len() != documents.len() { - return Err(EmbeddingError::DocumentError(format!( - "Expected {} embeddings, got {}", - documents.len(), - response.embeddings.len() - ))); + return Err(EmbeddingError::DocumentError( + format!( + "Expected {} embeddings, got {}", + documents.len(), + response.embeddings.len() + ) + .into(), + )); } Ok(response diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index ab1517d8..828f27bc 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, ExtractEmbeddingFields, + json_utils, Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,10 +121,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( - &self, - model: &str, - ) -> EmbeddingsBuilder { + pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model)) } @@ -242,7 +239,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 528faba5..f8d7d7a8 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, tool::EmbeddableTool}, + embeddings::{embed::EmbedError, tool::EmbeddableTool}, }; #[derive(Debug, thiserror::Error)] @@ -330,7 +330,7 @@ impl ToolSet { /// Convert tools in self to objects of type EmbeddableTool. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn embedabble_tools(&self) -> Result, ExtractEmbeddingFieldsError> { + pub fn embedabble_tools(&self) -> Result, EmbedError> { self.tools .values() .filter_map(|tool_type| { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index f4f067fe..4519f45a 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -171,7 +171,7 @@ impl VectorStoreIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = self.store.vector_search(prompt_embedding, n); @@ -192,7 +192,7 @@ impl VectorStoreIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = self.store.vector_search(prompt_embedding, n); diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 044d8c2a..b2b8c93e 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -16,7 +16,7 @@ pub enum VectorStoreError { JsonError(#[from] serde_json::Error), #[error("Datastore error: {0}")] - DatastoreError(#[from] Box), + DatastoreError(#[from] Box), #[error("Missing Id: {0}")] MissingIdError(String), diff --git a/rig-core/tests/extract_embedding_fields_macro.rs b/rig-core/tests/extract_embedding_fields_macro.rs index 5f8891ff..a5db1e24 100644 --- a/rig-core/tests/extract_embedding_fields_macro.rs +++ b/rig-core/tests/extract_embedding_fields_macro.rs @@ -1,5 +1,5 @@ -use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; -use rig::{ExtractEmbeddingFields, OneOrMany}; +use rig::embeddings::embed::ExtractEmbeddingFieldsError; +use rig::{Embed, OneOrMany}; use serde::Serialize; fn serialize(definition: Definition) -> Result, ExtractEmbeddingFieldsError> { diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 1a9089a5..d6dc71e6 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::Embedding; -use rig::{ExtractEmbeddingFields, OneOrMany}; +use rig::{Embed, OneOrMany}; use serde::Deserialize; #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 2eea2357..0d06b495 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -195,7 +195,7 @@ impl VectorStoreIndex for LanceDbVectorIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table @@ -241,7 +241,7 @@ impl VectorStoreIndex for LanceDbVectorIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cc16d4bc..cdd7de7d 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,7 +3,7 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::ExtractEmbeddingFields; +use rig::Embed; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 655f3939..fd440669 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -204,7 +204,7 @@ impl VectorStoreIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection @@ -271,7 +271,7 @@ impl VectorStoreIndex query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection From be15036cf4943c5ce03b2ae708a85dd9b34340ea Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 6 Nov 2024 18:00:32 -0500 Subject: [PATCH 2/9] refactor: refactor Embed trait, fix all imports, rename files, fix macro --- rig-core/Cargo.toml | 2 +- rig-core/examples/rag.rs | 4 +- rig-core/examples/test.rs | 76 -------------- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/rig-core-derive/src/basic.rs | 4 +- .../{extract_embedding_fields.rs => embed.rs} | 65 ++++-------- rig-core/rig-core-derive/src/lib.rs | 6 +- rig-core/src/embeddings/builder.rs | 98 ++++++++++++------- rig-core/src/embeddings/embed.rs | 81 ++++++++------- rig-core/src/embeddings/mod.rs | 2 +- rig-core/src/embeddings/tool.rs | 2 +- rig-core/src/lib.rs | 2 +- ...bedding_fields_macro.rs => embed_macro.rs} | 80 ++++++++------- rig-lancedb/examples/fixtures/lib.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 5 +- 16 files changed, 180 insertions(+), 253 deletions(-) delete mode 100644 rig-core/examples/test.rs rename rig-core/rig-core-derive/src/{extract_embedding_fields.rs => embed.rs} (60%) rename rig-core/tests/{extract_embedding_fields_macro.rs => embed_macro.rs} (64%) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 514bdca5..c705ba06 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,7 +35,7 @@ tokio-test = "0.4.4" derive = ["dep:rig-derive"] [[test]] -name = "extract_embedding_fields_macro" +name = "embed_macro" required-features = ["derive"] [[example]] diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3877e651..0e8d5a06 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -10,9 +10,9 @@ use rig::{ use serde::Serialize; // Shape of data that needs to be RAG'ed. -// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition` +// A vector search needs to be performed on the definitions, so we derive the `Embed` trait for `FakeDefinition` // and tag that field with `#[embed]`. -#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)] +#[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/examples/test.rs b/rig-core/examples/test.rs deleted file mode 100644 index 7eb30779..00000000 --- a/rig-core/examples/test.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::env; - -use rig::{ - embeddings::{embed::EmbedError, Embedding, EmbeddingError, EmbeddingModel, EmbeddingsBuilder}, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embed, OneOrMany, -}; -use serde::{Deserialize, Serialize}; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - word: String, - // #[embed] - definitions: Vec, -} - -impl Embed for FakeDefinition { - fn embed(&self, embedder: &mut rig::embeddings::Embedder) -> Result<(), EmbedError> { - for doc in &self.definitions { - embedder.embed(doc.clone()); - } - Ok(()) - } -} - -#[tokio::main] -async fn main() -> Result<(), anyhow::Error> { - // Create OpenAI client - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let openai_client = Client::new(&openai_api_key); - - let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - - let docs = vec![ - FakeDefinition { - id: "doc0".to_string(), - word: "flurbo".to_string(), - definitions: vec![ - "A green alien that lives on cold planets.".to_string(), - "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - word: "glarb-glarb".to_string(), - definitions: vec![ - "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - word: "linglingdong".to_string(), - definitions: vec![ - "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), - "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ]; - - let embeddings = model.embed_many(docs).await?; - - let data = vec![ - "What is a flurbo?", - "What is a glarb-glarb?", - "What is a linglingdong?", - ]; - - let embeddings = model.embed_many(data).await?; - - Ok(()) -} diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 089822d3..925ebca8 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index bebe87e4..f3a97498 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 39b72018..b0e5afd7 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator syn::Resu let (custom_targets, custom_target_size) = data_struct.custom()?; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement `ExtractEmbeddingFields` trait for the struct. + // ie. do not implement `Embed` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( name, @@ -27,14 +27,14 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } quote! { - let mut embed_targets = #basic_targets; - embed_targets.extend(#custom_targets) + #basic_targets; + #custom_targets; } } _ => { return Err(syn::Error::new_spanned( input, - "ExtractEmbeddingFields derive macro should only be used on structs", + "Embed derive macro should only be used on structs", )) } }; @@ -42,18 +42,13 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: `ExtractEmbeddingFields` trait is imported with the macro. + // Note: `Embed` trait is imported with the macro. - impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { - type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result, Self::Error> { + impl #impl_generics Embed for #name #ty_generics #where_clause { + fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> { #target_stream; - rig::OneOrMany::merge( - embed_targets.into_iter() - .collect::, _>>()? - ).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new) + Ok(()) } } }; @@ -84,21 +79,12 @@ impl StructParser for DataStruct { }) .collect::>(); - if !embed_targets.is_empty() { - ( - quote! { - vec![#(#embed_targets.extract_embedding_fields()),*] - }, - embed_targets.len(), - ) - } else { - ( - quote! { - vec![] - }, - 0, - ) - } + ( + quote! { + #(#embed_targets.embed(embedder)?;)* + }, + embed_targets.len(), + ) } fn custom(&self) -> syn::Result<(TokenStream, usize)> { @@ -109,25 +95,16 @@ impl StructParser for DataStruct { let field_name = &field.ident; quote! { - #custom_func_path(self.#field_name.clone()) + #custom_func_path(embedder, self.#field_name.clone())?; } }) .collect::>(); - Ok(if !embed_targets.is_empty() { - ( - quote! { - vec![#(#embed_targets),*] - }, - embed_targets.len(), - ) - } else { - ( - quote! { - vec![] - }, - 0, - ) - }) + Ok(( + quote! { + #(#embed_targets)* + }, + embed_targets.len(), + )) } } diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 8ad69a65..6d8b18dc 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -4,18 +4,18 @@ use syn::{parse_macro_input, DeriveInput}; mod basic; mod custom; -mod extract_embedding_fields; +mod embed; pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))] +#[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - extract_embedding_fields::expand_derive_embedding(&mut input) + embed::expand_derive_embedding(&mut input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index cb85d49e..edf7d2e5 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,21 +1,20 @@ //! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [ExtractEmbeddingFields] trait can be added to the [EmbeddingsBuilder]. +//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder]. use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt}; use crate::{ - embeddings::{Embed, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, TextEmbedder}, OneOrMany, }; -use super::{embed::EmbedError, Embedder}; - -/// Builder for creating a collection of embeddings. +/// Builder for creating a collection of embeddings from a vector of documents of type `T`. +/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the provider. pub struct EmbeddingsBuilder { model: M, - documents: Vec<(T, OneOrMany)>, + documents: Vec<(T, Vec)>, } impl EmbeddingsBuilder { @@ -27,18 +26,17 @@ impl EmbeddingsBuilder { } } - /// Add a document that implements `ExtractEmbeddingFields` to the builder. + /// Add a document that implements `Embed` to the builder. pub fn document(mut self, document: T) -> Result { - let mut embedder = Embedder::default(); + let mut embedder = TextEmbedder::default(); document.embed(&mut embedder)?; - self.documents - .push((document, OneOrMany::many(embedder.texts).unwrap())); + self.documents.push((document, embedder.texts)); Ok(self) } - /// Add many documents that implement `ExtractEmbeddingFields` to the builder. + /// Add many documents that implement `Embed` to the builder. pub fn documents(self, documents: impl IntoIterator) -> Result { let builder = documents .into_iter() @@ -56,13 +54,13 @@ impl EmbeddingsBuilder { /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// ExtractEmbeddingFields, +/// Embed, /// }; /// use serde::{Deserialize, Serialize}; /// /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. -/// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] /// struct FakeDefinition { /// id: String, /// word: String, @@ -112,10 +110,10 @@ impl EmbeddingsBuilder { pub async fn build(self) -> Result)>, EmbeddingError> { use stream::TryStreamExt; + // Store the documents and their texts in a HashMap for easy access let mut docs = HashMap::new(); let mut texts = HashMap::new(); - // Gather the texts to embed for each document for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() { docs.insert(i, doc); texts.insert(i, doc_texts); @@ -128,12 +126,11 @@ impl EmbeddingsBuilder { // Chunk them into batches the embedding API limit per request. .chunks(M::MAX_DOCUMENTS) // Generate the embeddings for each batch. - .map(|docs| async { - let embeddings = self - .model - .embed_texts(docs.into_iter().map(|(_, text)| text)) - .await?; - Ok::<_, EmbeddingError>(embeddings.into_iter().enumerate().collect::>()) + .map(|text| async { + let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip(); + + let embeddings = self.model.embed_texts(docs).await?; + Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::>()) }) // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) @@ -152,6 +149,8 @@ impl EmbeddingsBuilder { ) .await?; + println!("{:?}", embeddings); + // Merge the embeddings with their respective documents Ok(docs .into_iter() @@ -168,7 +167,7 @@ impl EmbeddingsBuilder { #[cfg(test)] mod tests { use crate::{ - embeddings::{embed::EmbedError, Embedder, Embedding, EmbeddingModel}, + embeddings::{embed::EmbedError, Embedding, EmbeddingModel, TextEmbedder}, Embed, }; @@ -205,7 +204,7 @@ mod tests { } impl Embed for FakeDefinition { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for definition in &self.definitions { embedder.embed(definition.clone()); } @@ -213,7 +212,7 @@ mod tests { } } - fn fake_definitions() -> Vec { + fn fake_definitions_multiple_text() -> Vec { vec![ FakeDefinition { id: "doc0".to_string(), @@ -232,7 +231,7 @@ mod tests { ] } - fn fake_definitions_2() -> Vec { + fn fake_definitions_multiple_text_2() -> Vec { vec![ FakeDefinition { id: "doc2".to_string(), @@ -252,13 +251,13 @@ mod tests { } impl Embed for FakeDefinitionSingle { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.definition.clone()); Ok(()) } } - fn fake_definitions_single() -> Vec { + fn fake_definitions_single_text() -> Vec { vec![ FakeDefinitionSingle { id: "doc0".to_string(), @@ -272,8 +271,8 @@ mod tests { } #[tokio::test] - async fn test_build_many() { - let fake_definitions = fake_definitions(); + async fn test_build_multiple_text() { + let fake_definitions = fake_definitions_multiple_text(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -306,8 +305,8 @@ mod tests { } #[tokio::test] - async fn test_build_single() { - let fake_definitions = fake_definitions_single(); + async fn test_build_single_text() { + let fake_definitions = fake_definitions_single_text(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -340,9 +339,9 @@ mod tests { } #[tokio::test] - async fn test_build_many_and_single() { - let fake_definitions = fake_definitions(); - let fake_definitions_single = fake_definitions_2(); + async fn test_build_multiple_and_single_text() { + let fake_definitions = fake_definitions_multiple_text(); + let fake_definitions_single = fake_definitions_multiple_text_2(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -375,4 +374,37 @@ mod tests { "Another fake definitions".to_string() ) } + + #[tokio::test] + async fn test_build_string() { + let bindings = fake_definitions_multiple_text(); + let fake_definitions = bindings.iter().map(|def| def.definitions.clone()); + + let fake_model = FakeModel; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.cmp(&fake_definition_2) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.1.len(), 2); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ) + } } diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index cda02d33..7df6c8c9 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -1,11 +1,7 @@ -//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. +//! The module defines the [Embed] trait, which must be implemented for types that can be embedded. -// use crate::one_or_many::OneOrMany; - -use super::EmbeddingModel; - -/// Error type used for when the `extract_embedding_fields` method fails. -/// Used by default implementations of `ExtractEmbeddingFields` for common types. +/// Error type used for when the `Embed.embed` method fails. +/// Used by default implementations of `Embed` for common types. #[derive(Debug, thiserror::Error)] #[error("{0}")] pub struct EmbedError(#[from] Box); @@ -17,14 +13,14 @@ impl EmbedError { } /// Derive this trait for structs whose fields need to be converted to vector embeddings. -/// The `extract_embedding_fields` method returns a `OneOrMany`. This function extracts the fields that need to be embedded and returns them as a list of strings. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// The `embed` method accumulates string values that need to be embedded by adding them to the `TextEmbedder`. +/// If an error occurs, the method should return `EmbedError`. /// # Example /// ```rust /// use std::env; /// /// use serde::{Deserialize, Serialize}; -/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; +/// use rig::{EmptyListError, Embed}; /// /// struct FakeDefinition { /// id: String, @@ -32,129 +28,130 @@ impl EmbedError { /// definitions: String, /// } /// -/// let fake_definition = FakeDefinition { -/// id: "doc1".to_string(), -/// word: "rock".to_string(), -/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() -/// }; -/// -/// impl ExtractEmbeddingFields for FakeDefinition { -/// type Error = EmptyListError; +/// impl Embed for FakeDefinition { +/// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. +/// self.definitions +/// .split(",") +/// .collect::>() +/// .into_iter() +/// .for_each(|s| { +/// embedder.embed(s.to_string()); +/// }); /// -/// fn extract_embedding_fields(&self) -> Result, Self::Error> { -/// // Embeddings only need to be generated for `definition` field. -/// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. -/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); -/// -/// OneOrMany::many(definitions) +/// Ok(()) /// } /// } /// ``` pub trait Embed { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError>; + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; } +/// Accumulates string values that need to be embedded. +/// Used by the `Embed` trait. #[derive(Default)] -pub struct Embedder { +pub struct TextEmbedder { pub texts: Vec, } -impl Embedder { +impl TextEmbedder { pub fn embed(&mut self, text: String) { self.texts.push(text); } } // ================================================================ -// Implementations of ExtractEmbeddingFields for common types +// Implementations of Embed for common types // ================================================================ + impl Embed for String { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.clone()); Ok(()) } } impl Embed for &str { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for i8 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for i16 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for i32 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for i64 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for i128 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for f32 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for f64 { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for bool { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for char { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.to_string()); Ok(()) } } impl Embed for serde_json::Value { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?); Ok(()) } } impl Embed for Vec { - fn embed(&self, embedder: &mut Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for item in self { item.embed(embedder).map_err(EmbedError::new)?; } diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index f008c050..4fc12ae9 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -9,6 +9,6 @@ pub mod embedding; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embed::{Embed, Embedder}; +pub use embed::{Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 1a2e78a6..b3b33c59 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -12,7 +12,7 @@ pub struct EmbeddableTool { } impl Embed for EmbeddableTool { - fn embed(&self, embedder: &mut super::embed::Embedder) -> Result<(), EmbedError> { + fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> { for doc in &self.embedding_docs { embedder.embed(doc.clone()); } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index c3f4fcf8..315b7333 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -84,4 +84,4 @@ pub use embeddings::Embed; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] -pub use rig_derive::ExtractEmbeddingFields; +pub use rig_derive::Embed; diff --git a/rig-core/tests/extract_embedding_fields_macro.rs b/rig-core/tests/embed_macro.rs similarity index 64% rename from rig-core/tests/extract_embedding_fields_macro.rs rename to rig-core/tests/embed_macro.rs index a5db1e24..bb41247a 100644 --- a/rig-core/tests/extract_embedding_fields_macro.rs +++ b/rig-core/tests/embed_macro.rs @@ -1,14 +1,16 @@ -use rig::embeddings::embed::ExtractEmbeddingFieldsError; -use rig::{Embed, OneOrMany}; +use rig::{ + embeddings::{embed::EmbedError, TextEmbedder}, + Embed, +}; use serde::Serialize; -fn serialize(definition: Definition) -> Result, ExtractEmbeddingFieldsError> { - Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(ExtractEmbeddingFieldsError::new)?, - )) +fn serialize(embedder: &mut TextEmbedder, definition: Definition) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition { id: String, word: String, @@ -40,16 +42,17 @@ fn test_custom_embed() { fake_definition.id, fake_definition.word ); + let embedder = &mut TextEmbedder::default(); + fake_definition.embed(embedder).unwrap(); + assert_eq!( - fake_definition.extract_embedding_fields().unwrap(), - OneOrMany::one( - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - ) + embedder.texts.first().unwrap().clone(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition2 { id: String, #[embed] @@ -75,18 +78,18 @@ fn test_custom_and_basic_embed() { fake_definition.id, fake_definition.word ); - assert_eq!( - fake_definition.extract_embedding_fields().unwrap().first(), - "house".to_string() - ); + let embedder = &mut TextEmbedder::default(); + fake_definition.embed(embedder).unwrap(); + + assert_eq!(embedder.texts.first().unwrap().clone(), "house".to_string()); assert_eq!( - fake_definition.extract_embedding_fields().unwrap().rest(), - vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + embedder.texts.last().unwrap().clone(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition3 { id: String, word: String, @@ -108,13 +111,13 @@ fn test_single_embed() { fake_definition.id, fake_definition.word ); - assert_eq!( - fake_definition.extract_embedding_fields().unwrap(), - OneOrMany::one(definition) - ) + let embedder = &mut TextEmbedder::default(); + fake_definition.embed(embedder).unwrap(); + + assert_eq!(embedder.texts.first().unwrap().clone(), definition) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct Company { id: String, company: String, @@ -132,28 +135,21 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let result = company.extract_embedding_fields().unwrap(); + let embedder = &mut TextEmbedder::default(); + company.embed(embedder).unwrap(); assert_eq!( - result, - OneOrMany::many(vec![ + embedder.texts, + vec![ "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() - ]) - .unwrap() + ] ); - - assert_eq!(result.first(), "25".to_string()); - - assert_eq!( - result.rest(), - vec!["30".to_string(), "35".to_string(), "40".to_string()] - ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct Company2 { id: String, #[embed] @@ -172,15 +168,17 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); + let embedder = &mut TextEmbedder::default(); + company.embed(embedder).unwrap(); + assert_eq!( - company.extract_embedding_fields().unwrap(), - OneOrMany::many(vec![ + embedder.texts, + vec![ "Google".to_string(), "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() - ]) - .unwrap() + ] ); } diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index d6dc71e6..954494e5 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -6,7 +6,7 @@ use rig::embeddings::Embedding; use rig::{Embed, OneOrMany}; use serde::Deserialize; -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] +#[derive(Embed, Clone, Deserialize, Debug)] pub struct FakeDefinition { pub id: String, #[embed] diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cdd7de7d..36bc7196 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,15 +3,14 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::Embed; use rig::{ - embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, Embed, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] +#[derive(Embed, Clone, Deserialize, Debug)] struct FakeDefinition { #[serde(rename = "_id")] id: String, From 43fd8167d8c06e3bc80942080860850f623156ff Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 7 Nov 2024 16:22:38 -0500 Subject: [PATCH 3/9] fix(embed trait): fix errors while testing --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 4 +- rig-core/examples/rag_dynamic_tools.rs | 7 +-- rig-core/rig-core-derive/src/basic.rs | 4 +- rig-core/rig-core-derive/src/embed.rs | 10 ++--- rig-core/rig-core-derive/src/lib.rs | 6 +-- rig-core/src/embeddings/builder.rs | 12 ++--- rig-core/src/embeddings/embed.rs | 12 +++-- rig-core/src/embeddings/embedding.rs | 40 +---------------- rig-core/src/embeddings/mod.rs | 4 +- rig-core/src/embeddings/tool.rs | 14 +++--- rig-core/src/tool.rs | 8 ++-- rig-mongodb/examples/vector_search_mongodb.rs | 5 ++- rig-mongodb/src/lib.rs | 45 ++++++------------- 14 files changed, 63 insertions(+), 110 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 576491d5..bcd2ea99 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .documents(toolset.schema()?)? .build() .await?; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 0e8d5a06..ec0f38e4 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -9,8 +9,8 @@ use rig::{ }; use serde::Serialize; -// Shape of data that needs to be RAG'ed. -// A vector search needs to be performed on the definitions, so we derive the `Embed` trait for `FakeDefinition` +// Data to be RAG'ed. +// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition` // and tag that field with `#[embed]`. #[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index c72276c7..9214caff 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -155,8 +155,9 @@ async fn main() -> Result<(), anyhow::Error> { .dynamic_tool(Subtract) .build(); - let embeddings = embedding_model - .embed_many(toolset.embedabble_tools()?) + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) + .documents(toolset.schema()?)? + .build() .await?; let index = InMemoryVectorStore::default() diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index b0e5afd7..b9c1e5c4 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -2,7 +2,7 @@ use syn::{parse_quote, Attribute, DataStruct, Meta}; use crate::EMBED; -/// Finds and returns fields with simple #[embed] attribute tags only. +/// Finds and returns fields with simple `#[embed]` attribute tags only. pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { data_struct.fields.iter().filter(|field| { field.attrs.iter().any(|attribute| match attribute { @@ -15,7 +15,7 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator syn::Resu let (basic_targets, basic_target_size) = data_struct.basic(generics); let (custom_targets, custom_target_size) = data_struct.custom()?; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an empty TokenStream. // ie. do not implement `Embed` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( @@ -57,17 +57,17 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } trait StructParser { - // Handles fields tagged with #[embed] + // Handles fields tagged with `#[embed]` fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); - // Handles fields tagged with #[embed(embed_with = "...")] + // Handles fields tagged with `#[embed(embed_with = "...")]` fn custom(&self) -> syn::Result<(TokenStream, usize)>; } impl StructParser for DataStruct { fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { let embed_targets = basic_embed_fields(self) - // Iterate over every field tagged with #[embed] + // Iterate over every field tagged with `#[embed]` .map(|field| { add_struct_bounds(generics, &field.ty); @@ -89,7 +89,7 @@ impl StructParser for DataStruct { fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? - // Iterate over every field tagged with #[embed(embed_with = "...")] + // Iterate over every field tagged with `#[embed(embed_with = "...")]` .into_iter() .map(|(field, custom_func_path)| { let field_name = &field.ident; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 6d8b18dc..2d14e8cf 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -8,9 +8,9 @@ mod embed; pub(crate) const EMBED: &str = "embed"; -// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro -// https://doc.rust-lang.org/reference/procedural-macros.html - +/// References: +/// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro +/// https://doc.rust-lang.org/reference/procedural-macros.html #[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index edf7d2e5..8f0a5dd1 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,4 +1,5 @@ -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded +//! and batch generates the embeddings for each object when built. //! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder]. use std::{cmp::max, collections::HashMap}; @@ -110,20 +111,21 @@ impl EmbeddingsBuilder { pub async fn build(self) -> Result)>, EmbeddingError> { use stream::TryStreamExt; - // Store the documents and their texts in a HashMap for easy access + // Store the documents and their texts in a HashMap for easy access. let mut docs = HashMap::new(); let mut texts = HashMap::new(); + // Iterate over all documents in the builder and insert their docs and texts into the lookup stores. for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() { docs.insert(i, doc); texts.insert(i, doc_texts); } - // Compute the embeddings + // Compute the embeddings. let mut embeddings = stream::iter(texts.into_iter()) // Merge the texts of each document into a single list of texts. .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text)))) - // Chunk them into batches the embedding API limit per request. + // Chunk them into batches. Each batch size is at most the embedding API limit per request. .chunks(M::MAX_DOCUMENTS) // Generate the embeddings for each batch. .map(|text| async { @@ -149,8 +151,6 @@ impl EmbeddingsBuilder { ) .await?; - println!("{:?}", embeddings); - // Merge the embeddings with their respective documents Ok(docs .into_iter() diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 7df6c8c9..54c14853 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -1,6 +1,6 @@ -//! The module defines the [Embed] trait, which must be implemented for types that can be embedded. +//! The module defines the [Embed] trait, which must be implemented for types that can be embedded by the `EmbeddingsBuilder`. -/// Error type used for when the `Embed.embed` method fails. +/// Error type used for when the `embed` method fo the `Embed` trait fails. /// Used by default implementations of `Embed` for common types. #[derive(Debug, thiserror::Error)] #[error("{0}")] @@ -12,7 +12,7 @@ impl EmbedError { } } -/// Derive this trait for structs whose fields need to be converted to vector embeddings. +/// Derive this trait for objects that need to be converted to vector embeddings. /// The `embed` method accumulates string values that need to be embedded by adding them to the `TextEmbedder`. /// If an error occurs, the method should return `EmbedError`. /// # Example @@ -62,6 +62,12 @@ impl TextEmbedder { } } +pub fn to_text(item: impl Embed) -> Result, EmbedError> { + let mut embedder = TextEmbedder::default(); + item.embed(&mut embedder)?; + Ok(embedder.texts) +} + // ================================================================ // Implementations of Embed for common types // ================================================================ diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 8748c6f1..d033f57e 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -10,10 +10,6 @@ use serde::{Deserialize, Serialize}; -use crate::OneOrMany; - -use super::{Embed, EmbeddingsBuilder}; - #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) @@ -51,7 +47,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { documents: impl IntoIterator + Send, ) -> impl std::future::Future, EmbeddingError>> + Send; - /// Embed a single text document + /// Embed a single text document. fn embed_text( &self, document: &str, @@ -64,40 +60,6 @@ pub trait EmbeddingModel: Clone + Sync + Send { .expect("There should be at least one embedding")) } } - - /// Embed a single document - fn embed( - &self, - document: T, - ) -> impl std::future::Future, EmbeddingError>> + Send - { - async { - Ok(self - .embed_many(vec![document]) - .await? - .pop() - .map(|(_, embedding)| embedding) - .expect("There should be at least one embedding")) - } - } - - /// Embed multiple documents in a single request - fn embed_many + Send>( - &self, - documents: I, - ) -> impl std::future::Future)>, EmbeddingError>> + Send - where - ::IntoIter: std::marker::Send, - { - async { - let builder = EmbeddingsBuilder::new(self.clone()); - builder - .documents(documents) - .map_err(|err| EmbeddingError::DocumentError(Box::new(err)))? - .build() - .await - } - } } /// Struct that holds a single document and its embedding. diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 4fc12ae9..bd2f72dc 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -9,6 +9,6 @@ pub mod embedding; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embed::{Embed, EmbedError, TextEmbedder}; +pub use embed::{to_text, Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; -pub use tool::EmbeddableTool; +pub use tool::ToolSchema; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index b3b33c59..bcea7c7d 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -5,13 +5,13 @@ use super::embed::EmbedError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. #[derive(Clone, Serialize, Default, Eq, PartialEq)] -pub struct EmbeddableTool { +pub struct ToolSchema { pub name: String, pub context: serde_json::Value, pub embedding_docs: Vec, } -impl Embed for EmbeddableTool { +impl Embed for ToolSchema { fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> { for doc in &self.embedding_docs { embedder.embed(doc.clone()); @@ -20,13 +20,13 @@ impl Embed for EmbeddableTool { } } -impl EmbeddableTool { - /// Convert item that implements ToolEmbeddingDyn to an EmbeddableTool. +impl ToolSchema { + /// Convert item that implements ToolEmbeddingDyn to an ToolSchema. /// # Example /// ```rust /// use rig::{ /// completion::ToolDefinition, - /// embeddings::EmbeddableTool, + /// embeddings::ToolSchema, /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, /// }; /// use serde_json::json; @@ -77,13 +77,13 @@ impl EmbeddableTool { /// fn context(&self) -> Self::Context {} /// } /// - /// let tool = EmbeddableTool::try_from(&Nothing).unwrap(); + /// let tool = ToolSchema::try_from(&Nothing).unwrap(); /// /// assert_eq!(tool.name, "nothing".to_string()); /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); /// ``` pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { - Ok(EmbeddableTool { + Ok(ToolSchema { name: tool.name(), context: tool.context().map_err(EmbedError::new)?, embedding_docs: tool.embedding_docs(), diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index f8d7d7a8..a70f3c3b 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embed::EmbedError, tool::EmbeddableTool}, + embeddings::{embed::EmbedError, tool::ToolSchema}, }; #[derive(Debug, thiserror::Error)] @@ -327,15 +327,15 @@ impl ToolSet { Ok(docs) } - /// Convert tools in self to objects of type EmbeddableTool. + /// Convert tools in self to objects of type ToolSchema. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn embedabble_tools(&self) -> Result, EmbedError> { + pub fn schema(&self) -> Result, EmbedError> { self.tools .values() .filter_map(|tool_type| { if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(&**tool)) + Some(ToolSchema::try_from(&**tool)) } else { None } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 36bc7196..b0867064 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -6,7 +6,7 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, Embed, }; -use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use rig_mongodb::{MongoDbVectorIndex, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. @@ -96,7 +96,8 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store. // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorStore::new(collection).index( + let index = MongoDbVectorIndex::new( + collection, model, "vector_index", SearchParams::new("embedding"), diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index fd440669..3631cc8d 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -11,9 +11,10 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } +/// A vector index for a MongoDB collection. /// # Example /// ``` -/// use rig_mongodb::{MongoDbVectorStore, SearchParams}; +/// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -26,37 +27,13 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. -/// let index = MongoDbVectorStore::new(collection).index( +/// let index = MongoDbVectorIndex::new( +/// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. /// ); /// ``` -pub struct MongoDbVectorStore { - collection: mongodb::Collection, -} - -impl MongoDbVectorStore { - /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection) -> Self { - Self { collection } - } - - /// Create a new `MongoDbVectorIndex` from an existing `MongoDbVectorStore`. - /// - /// The index (of type "vector") must already exist for the MongoDB collection. - /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub fn index( - &self, - model: M, - index_name: &str, - search_params: SearchParams, - ) -> MongoDbVectorIndex { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) - } -} - -/// A vector index for a MongoDB collection. pub struct MongoDbVectorIndex { collection: mongodb::Collection, model: M, @@ -100,6 +77,10 @@ impl MongoDbVectorIndex { } impl MongoDbVectorIndex { + /// Create a new `MongoDbVectorIndex`. + /// + /// The index (of type "vector") must already exist for the MongoDB collection. + /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. pub fn new( collection: mongodb::Collection, model: M, @@ -167,7 +148,7 @@ impl VectorStoreIndex /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. /// # Example /// ``` - /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -188,7 +169,8 @@ impl VectorStoreIndex /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// - /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. @@ -242,7 +224,7 @@ impl VectorStoreIndex /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. /// # Example /// ``` - /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -255,7 +237,8 @@ impl VectorStoreIndex /// /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. - /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. From b4762175dbb522ca9809e980a6545e846247279a Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 7 Nov 2024 16:58:38 -0500 Subject: [PATCH 4/9] fix(lancedb): examples --- rig-lancedb/examples/vector_search_local_ann.rs | 1 - rig-lancedb/examples/vector_search_s3_ann.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 7ffd6b12..84679e3f 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -35,7 +35,6 @@ async fn main() -> Result<(), anyhow::Error> { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) - .collect(), )? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 824deda0..160dfa10 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -41,7 +41,6 @@ async fn main() -> Result<(), anyhow::Error> { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) - .collect(), )? .build() .await?; From 3e1ff232af223198e5cd63e61a91c42f8dc0e28a Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 7 Nov 2024 16:59:46 -0500 Subject: [PATCH 5/9] docs: fix hyperlink --- rig-core/rig-core-derive/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 2d14e8cf..4ce20cfa 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -9,8 +9,8 @@ mod embed; pub(crate) const EMBED: &str = "embed"; /// References: -/// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro -/// https://doc.rust-lang.org/reference/procedural-macros.html +/// +/// #[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); From 1c7208074137e4cce21bf9ec81875522e54d5a02 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 7 Nov 2024 17:02:48 -0500 Subject: [PATCH 6/9] fmt: cargo fmt --- rig-core/examples/rag_dynamic_tools.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 9214caff..9f0b7f7a 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, From add9d0b4a817f3ef07215c6730c9d13108711a11 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 14 Nov 2024 18:06:36 -0500 Subject: [PATCH 7/9] PR; make requested changes --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/src/embeddings/embed.rs | 8 ++++---- rig-core/src/embeddings/mod.rs | 2 +- rig-core/src/lib.rs | 2 +- rig-core/src/tool.rs | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index bcd2ea99..149b1ce4 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.schema()?)? + .documents(toolset.schemas()?)? .build() .await?; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ec0f38e4..cecd20ce 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -9,7 +9,7 @@ use rig::{ }; use serde::Serialize; -// Data to be RAG'ed. +// Data to be RAGged. // A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition` // and tag that field with `#[embed]`. #[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 9f0b7f7a..459b017b 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.schema()?)? + .documents(toolset.schemas()?)? .build() .await?; diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 54c14853..7ff43198 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -20,7 +20,7 @@ impl EmbedError { /// use std::env; /// /// use serde::{Deserialize, Serialize}; -/// use rig::{EmptyListError, Embed}; +/// use rig::Embed; /// /// struct FakeDefinition { /// id: String, @@ -35,8 +35,6 @@ impl EmbedError { /// // That way, different embeddings can be generated for each definition in the definitions string. /// self.definitions /// .split(",") -/// .collect::>() -/// .into_iter() /// .for_each(|s| { /// embedder.embed(s.to_string()); /// }); @@ -62,7 +60,9 @@ impl TextEmbedder { } } -pub fn to_text(item: impl Embed) -> Result, EmbedError> { +/// Client-side function to convert an object that implements the `Embed` trait to a vector of strings. +/// Similar to `serde`'s `serde_json::to_string()` function +pub fn to_texts(item: impl Embed) -> Result, EmbedError> { let mut embedder = TextEmbedder::default(); item.embed(&mut embedder)?; Ok(embedder.texts) diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index bd2f72dc..1ae16436 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -9,6 +9,6 @@ pub mod embedding; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embed::{to_text, Embed, EmbedError, TextEmbedder}; +pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; pub use tool::ToolSchema; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 315b7333..6c5db7ab 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -80,7 +80,7 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::Embed; +pub use embeddings::{to_texts, Embed}; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index a70f3c3b..198be04c 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -330,7 +330,7 @@ impl ToolSet { /// Convert tools in self to objects of type ToolSchema. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn schema(&self) -> Result, EmbedError> { + pub fn schemas(&self) -> Result, EmbedError> { self.tools .values() .filter_map(|tool_type| { From 64f2db6a033661583feafaad0e1f4cdbc637aba6 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Fri, 15 Nov 2024 10:09:44 -0500 Subject: [PATCH 8/9] fix: change visibility of struct field --- rig-core/src/embeddings/embed.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 7ff43198..2fd26c2a 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -51,7 +51,7 @@ pub trait Embed { /// Used by the `Embed` trait. #[derive(Default)] pub struct TextEmbedder { - pub texts: Vec, + pub(crate) texts: Vec, } impl TextEmbedder { From 4d41c353bbe06d267f73e5c54e488b7627c1b53e Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Fri, 15 Nov 2024 10:17:48 -0500 Subject: [PATCH 9/9] fix: failing tests --- rig-core/tests/embed_macro.rs | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/rig-core/tests/embed_macro.rs b/rig-core/tests/embed_macro.rs index bb41247a..778b70bd 100644 --- a/rig-core/tests/embed_macro.rs +++ b/rig-core/tests/embed_macro.rs @@ -1,6 +1,6 @@ use rig::{ embeddings::{embed::EmbedError, TextEmbedder}, - Embed, + to_texts, Embed, }; use serde::Serialize; @@ -42,11 +42,8 @@ fn test_custom_embed() { fake_definition.id, fake_definition.word ); - let embedder = &mut TextEmbedder::default(); - fake_definition.embed(embedder).unwrap(); - assert_eq!( - embedder.texts.first().unwrap().clone(), + to_texts(fake_definition).unwrap().first().unwrap().clone(), "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -78,13 +75,12 @@ fn test_custom_and_basic_embed() { fake_definition.id, fake_definition.word ); - let embedder = &mut TextEmbedder::default(); - fake_definition.embed(embedder).unwrap(); + let texts = to_texts(fake_definition).unwrap(); - assert_eq!(embedder.texts.first().unwrap().clone(), "house".to_string()); + assert_eq!(texts.first().unwrap().clone(), "house".to_string()); assert_eq!( - embedder.texts.last().unwrap().clone(), + texts.last().unwrap().clone(), "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) } @@ -111,10 +107,10 @@ fn test_single_embed() { fake_definition.id, fake_definition.word ); - let embedder = &mut TextEmbedder::default(); - fake_definition.embed(embedder).unwrap(); - - assert_eq!(embedder.texts.first().unwrap().clone(), definition) + assert_eq!( + to_texts(fake_definition).unwrap().first().unwrap().clone(), + definition + ) } #[derive(Embed)] @@ -135,11 +131,8 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let embedder = &mut TextEmbedder::default(); - company.embed(embedder).unwrap(); - assert_eq!( - embedder.texts, + to_texts(company).unwrap(), vec![ "25".to_string(), "30".to_string(), @@ -168,11 +161,8 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); - let embedder = &mut TextEmbedder::default(); - company.embed(embedder).unwrap(); - assert_eq!( - embedder.texts, + to_texts(company).unwrap(), vec![ "Google".to_string(), "25".to_string(),