From 223139e8325a3ecca6d04cecc42893def6beb7db Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 22 Oct 2024 15:41:53 -0400 Subject: [PATCH 1/6] docs: add and fix docstrings and examples --- rig-core/examples/calculator_chatbot.rs | 7 +- rig-core/examples/rag.rs | 15 +- rig-core/examples/rag_dynamic_tools.rs | 7 +- rig-core/examples/vector_search.rs | 9 +- rig-core/examples/vector_search_cohere.rs | 9 +- rig-core/src/lib.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 27 +++ rig-core/src/vector_store/mod.rs | 3 + .../examples/vector_search_local_ann.rs | 17 +- .../examples/vector_search_local_enn.rs | 13 +- rig-lancedb/examples/vector_search_s3_ann.rs | 13 +- rig-lancedb/src/lib.rs | 147 ++++++-------- rig-mongodb/src/lib.rs | 190 +++++++----------- 13 files changed, 196 insertions(+), 263 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 3f9f0b1b..723bfada 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -252,12 +252,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, "name")? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ab1387a1..d03902d1 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -10,8 +10,9 @@ use rig::{ use serde::Serialize; // Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] +// A vector search needs to be performed on the definitions, so we derive the `Embeddable` trait for `FakeDefinition` +// and tag that field with `#[embed]`. +#[derive(Embeddable, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] @@ -26,6 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the definitions of all the documents using the specified embedding model. let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ FakeDefinition { @@ -54,14 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index bdad5109..c3a2c251 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -161,12 +161,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, "name")? 
.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 5aebe12d..b40f271c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -57,14 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(model); let results = index diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 54adc598..da1e474b 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -58,14 +58,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(search_model); let results = index diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 07c59f96..6b337073 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,7 +78,7 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::embeddable::Embeddable; +pub use embeddings::Embeddable; pub use one_or_many::OneOrMany; #[cfg(feature = "derive")] diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 208b5f13..5ab9d6e7 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,6 +76,33 @@ impl InMemoryVectorStore { Ok(self) } + /// Add documents to the store. Define the name of the field in the document that contains the id. + /// Returns the store with the added documents. + pub fn add_documents_with_id( + mut self, + documents: Vec<(D, OneOrMany)>, + id_field: &str, + ) -> Result { + for (doc, embeddings) in documents { + if let serde_json::Value::Object(o) = + serde_json::to_value(&doc).map_err(VectorStoreError::JsonError)? + { + match o.get(id_field) { + Some(serde_json::Value::String(s)) => { + self.embeddings.insert(s.clone(), (doc, embeddings)); + } + _ => { + return Err(VectorStoreError::MissingIdError(format!( + "Document does not have a field {id_field}" + ))); + } + } + }; + } + + Ok(self) + } + /// Get the document by its id and deserialize it into the given type. 
pub fn get_document Deserialize<'a>>( &self, diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 38e45d0e..044d8c2a 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,6 +17,9 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box), + + #[error("Missing Id: {0}")] + MissingIdError(String), } /// Trait for vector store indexes diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..7ffd6b12 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,12 +3,12 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -40,12 +40,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -61,10 +62,10 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default(); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index - let results = vector_store + let results = vector_store_index .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..859442be 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,11 +3,11 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -33,17 +33,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Create table with embeddings. 
- let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..824deda0 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,11 +4,11 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -46,12 +46,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -73,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 18d87d7f..eaaffbe3 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -20,79 +20,17 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } +/// Type on which vector searches can be performed for a lanceDb table. /// # Example /// ``` -/// use std::{env, sync::Arc}; - -/// use arrow_array::RecordBatchIterator; -/// use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; -/// use lancedb::index::vector::IvfPqIndexBuilder; -/// use rig::vector_store::VectorStoreIndex; -/// use rig::{ -/// embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, -/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// }; -/// use rig_lancedb::{LanceDbVectorStore, SearchParams}; -/// -/// #[path = "../examples/fixtures/lib.rs"] -/// mod fixture; -/// -/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); -/// -/// // Select an embedding model. 
-/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
-///
-/// // Initialize LanceDB locally.
-/// let db = lancedb::connect("data/lancedb-store").execute().await?;
-///
-/// // Generate embeddings for the test data.
-/// let embeddings = EmbeddingsBuilder::new(model.clone())
-///     .documents(fake_definitions())?
-///     // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
-///     .documents(
-///         (0..256)
-///             .map(|i| FakeDefinition {
-///                 id: format!("doc{}", i),
-///                 definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
-///             })
-///             .collect(),
-///     )?
-///     .build()
-///     .await?;
 ///
-/// // Create table with embeddings.
-/// let record_batch = as_record_batch(embeddings, model.ndims());
-/// let table = db
-///     .create_table(
-///         "definitions",
-///         RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))),
-///     )
-///     .execute()
-///     .await?;
-///
-/// // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
-/// table
-///     .create_index(
-///         &["embedding"],
-///         lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()),
-///     )
-///     .execute()
-///     .await?;
-///
-/// // Define search_params params that will be used by the vector store to perform the vector search.
-/// let search_params = SearchParams::default();
-/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;
-///
-/// // Query the index
-/// let results = vector_store
-///     .top_n::("My boss says I zindle too much, what does that mean?", 1)
-///     .await?;
-///
-/// println!("Results: {:?}", results);
+/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
+/// use rig::embeddings::EmbeddingModel;
+///
+/// async fn create_index(table: lancedb::Table, model: impl EmbeddingModel) {
+///     let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default())
+///         .await
+///         .expect("Failed to create LanceDbVectorIndex");
+/// }
 /// ```
-pub struct LanceDbVectorStore {
+pub struct LanceDbVectorIndex {
     /// Defines which model is used to generate embeddings for the vector store.
     model: M,
     /// LanceDB table containing embeddings.
     table: lancedb::Table,
     /// Column name in `table` that contains the id of a record.
     id_field: String,
     /// Vector search params that are used during a vector search operation.
     search_params: SearchParams,
 }
 
-impl LanceDbVectorStore {
+impl LanceDbVectorIndex {
+    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
+    /// Define the id field name of the table.
+    /// Define search parameters that will be used to perform vector searches on the table.
+    pub async fn new(
+        table: lancedb::Table,
+        model: M,
+        id_field: &str,
+        search_params: SearchParams,
+    ) -> Result {
+        Ok(Self {
+            table,
+            model,
+            id_field: id_field.to_string(),
+            search_params,
+        })
+    }
+
     /// Apply the search_params to the vector query.
     /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
     fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
@@ -155,6 +110,10 @@ pub enum SearchType {
 }
 
 /// Parameters used to perform a vector search on a LanceDb table.
+/// # Example +/// ``` +/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine); +/// ``` #[derive(Debug, Clone, Default)] pub struct SearchParams { distance_type: Option, @@ -215,26 +174,22 @@ impl SearchParams { } } -impl LanceDbVectorStore { - /// Create an instance of `LanceDbVectorStore` with an existing table and model. - /// Define the id field name of the table. - /// Define search parameters that will be used to perform vector searches on the table. - pub async fn new( - table: lancedb::Table, - model: M, - id_field: &str, - search_params: SearchParams, - ) -> Result { - Ok(Self { - table, - model, - id_field: id_field.to_string(), - search_params, - }) - } -} - -impl VectorStoreIndex for LanceDbVectorStore { +impl VectorStoreIndex for LanceDbVectorIndex { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n Deserialize<'a> + Send>( &self, query: &str, @@ -269,6 +224,18 @@ impl VectorStoreIndex for LanceDbVectorStore .collect() } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n_ids( &self, query: &str, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 3201c648..a803d385 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,134 +7,35 @@ use rig::{ }; use serde::Deserialize; +fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + /// # Example /// ``` -/// use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -/// use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -/// use serde::{Deserialize, Serialize}; -/// use std::env; - -/// use rig::Embeddable; -/// use rig::{ -/// embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, -/// }; /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; - -/// // Shape of data that needs to be RAG'ed. -/// // The definition field will be used to generate embeddings. -/// #[derive(Embeddable, Clone, Deserialize, Debug)] -/// struct FakeDefinition { -/// #[serde(rename = "_id")] -/// id: String, -/// #[embed] -/// definition: String, -/// } - -/// #[derive(Clone, Deserialize, Debug, Serialize)] -/// struct Link { -/// word: String, -/// link: String, -/// } - -/// // Shape of the document to be stored in MongoDB, with embeddings. 
-/// #[derive(Serialize, Debug)] +/// use rig::embeddings::EmbeddingModel; +/// +/// #[derive(serde::Serialize, Debug)] /// struct Document { /// #[serde(rename = "_id")] /// id: String, /// definition: String, /// embedding: Vec, /// } -/// // Initialize OpenAI client -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); - -/// // Initialize MongoDB client -/// let mongodb_connection_string = -/// env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); -/// let options = ClientOptions::parse(mongodb_connection_string) -/// .await -/// .expect("MongoDB connection string should be valid"); - -/// let mongodb_client = -/// MongoClient::with_options(options).expect("MongoDB client options should be valid"); - -/// // Initialize MongoDB vector store -/// let collection: Collection = mongodb_client -/// .database("knowledgebase") -/// .collection("context"); - -/// // Select the embedding model and generate our embeddings -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - -/// let fake_definitions = vec![ -/// FakeDefinition { -/// id: "doc0".to_string(), -/// definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc1".to_string(), -/// definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc2".to_string(), -/// definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), -/// } -/// ]; - -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .documents(fake_definitions)? -/// .build() -/// .await?; - -/// let mongo_documents = embeddings -/// .iter() -/// .map( -/// |(FakeDefinition { id, definition, .. }, embedding)| Document { -/// id: id.clone(), -/// definition: definition.clone(), -/// embedding: embedding.first().vec.clone(), -/// }, -/// ) -/// .collect::>(); - -/// match collection.insert_many(mongo_documents, None).await { -/// Ok(_) => println!("Documents added successfully"), -/// Err(e) => println!("Error adding documents: {:?}", e), -/// }; - -/// // Create a vector index on our vector store. -/// // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. -/// // IMPORTANT: Reuse the same model that was used to generate the embeddings -/// let index = MongoDbVectorStore::new(collection).index( -/// model, -/// "vector_index", -/// SearchParams::new("embedding"), -/// ); - -/// // Query the index -/// let results = index -/// .top_n::("What is a linglingdong?", 1) -/// .await?; - -/// println!("Results: {:?}", results); - -/// let id_results = index -/// .top_n_ids("What is a linglingdong?", 1) -/// .await? -/// .into_iter() -/// .map(|(score, id)| (score, id)) -/// .collect::>(); - -/// println!("ID results: {:?}", id_results); +/// +/// fn create_index(collection: mongodb::Collection, model: EmbeddingModel) { +/// let index = MongoDbVectorStore::new(collection).index( +/// model, +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+/// ); +/// } /// ``` pub struct MongoDbVectorStore { collection: mongodb::Collection, } -fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. pub fn new(collection: mongodb::Collection) -> Self { @@ -263,6 +164,40 @@ impl SearchParams { impl VectorStoreIndex for MongoDbVectorIndex { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec, + /// } + /// + /// #[derive(serde::Deserialize, Debug)] + /// struct Definition { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// } + /// + /// fn execute_search(collection: mongodb::Collection, model: EmbeddingModel) { + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n Deserialize<'a> + Send>( &self, query: &str, @@ -303,6 +238,33 @@ impl VectorStoreIndex Ok(results) } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec, + /// } + /// + /// fn execute_search(collection: mongodb::Collection, model: EmbeddingModel) { + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+ /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n_ids( &self, query: &str, From ed9e038876b64c76d82f4daf23b116767d483316 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 22 Oct 2024 16:36:07 -0400 Subject: [PATCH 2/6] docs: add more doc tests --- Cargo.lock | 36 ++++++++++++++++ rig-core/Cargo.toml | 3 +- rig-core/src/embeddings/builder.rs | 12 +++++- rig-core/src/embeddings/embeddable.rs | 61 +++++++++++++------------- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/embeddings/tool.rs | 62 ++++++++++++++++++++++++++- rig-core/src/lib.rs | 2 +- 7 files changed, 143 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 75f67709..1a0a5eed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,28 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "async-trait" version = "0.1.83" @@ -4003,6 +4025,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-test", "tracing", "tracing-subscriber", ] @@ -5122,6 +5145,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.12" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index ea910406..ff2df2da 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -29,6 +29,7 @@ rig-derive = { path = "./rig-core-derive", optional = true } anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" +tokio-test = "0.4.4" [features] derive = ["dep:rig-derive"] @@ -47,4 +48,4 @@ required-features = ["derive"] [[example]] name = "vector_search_cohere" -required-features = ["derive"] \ No newline at end of file +required-features = ["derive"] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 887dbf0e..fed2ef06 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -73,6 +73,7 @@ impl EmbeddingsBuilder { /// /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// +/// # tokio_test::block_on(async { /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ /// FakeDefinition { @@ -99,9 +100,16 @@ impl EmbeddingsBuilder { /// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() /// ] /// }, -/// ])? 
+/// ]) +/// .unwrap() /// .build() -/// .await?; +/// .await +/// .unwrap(); +/// +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc0" && embeddings.len() == 2), true); +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc1" && embeddings.len() == 2), true); +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); +/// }) /// ``` impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index f5a69fd6..1d70ffcf 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,33 +1,4 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! # Example -//! ```rust -//! use std::env; -//! -//! use serde::{Deserialize, Serialize}; -//! use rig::OneOrMany; -//! -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! definition: String, -//! } -//! -//! let fake_definition = FakeDefinition { -//! id: "doc1".to_string(), -//! word: "hello".to_string(), -//! definition: "used as a greeting or to begin a conversation".to_string() -//! }; -//! -//! impl Embeddable for FakeDefinition { -//! type Error = anyhow::Error; -//! -//! fn embeddable(&self) -> Result, Self::Error> { -//! // Embeddigns only need to be generated for `definition` field. -//! // Select it from the struct and return it as a single item. -//! Ok(OneOrMany::one(self.definition.clone())) -//! } -//! } -//! ``` use crate::one_or_many::OneOrMany; @@ -46,6 +17,38 @@ impl EmbeddableError { /// Trait for types that can be embedded. /// The `embeddable` method returns a `OneOrMany` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, Embeddable}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl Embeddable for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn embeddable(&self) -> Result, Self::Error> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. 
+/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` pub trait Embeddable { type Error: std::error::Error + Sync + Send + 'static; diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 763e0f30..b8ad9b62 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -11,3 +11,4 @@ pub mod tool; pub use builder::EmbeddingsBuilder; pub use embeddable::Embeddable; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 139b11b8..c7c23b87 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -20,7 +20,67 @@ impl Embeddable for EmbeddableTool { } impl EmbeddableTool { - /// Convert item that implements ToolEmbedding to an EmbeddableTool. + /// Convert item that implements ToolEmbeddingDyn to an EmbeddableTool. + /// # Example + /// ```rust + /// use rig::{ + /// completion::ToolDefinition, + /// embeddings::EmbeddableTool, + /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, + /// }; + /// use serde_json::json; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Math error")] + /// struct NothingError; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Init error")] + /// struct InitError; + /// + /// struct Nothing; + /// impl Tool for Nothing { + /// const NAME: &'static str = "nothing"; + /// + /// type Error = NothingError; + /// type Args = (); + /// type Output = (); + /// + /// async fn definition(&self, _prompt: String) -> ToolDefinition { + /// serde_json::from_value(json!({ + /// "name": "nothing", + /// "description": "nothing", + /// "parameters": {} + /// })) + /// .expect("Tool Definition") + /// } + /// + /// async fn call(&self, args: Self::Args) -> Result { + /// Ok(()) + /// } + /// } + /// + /// impl ToolEmbedding for Nothing { + /// type InitError = InitError; + /// type Context = (); + /// type State = (); + /// + /// fn init(_state: Self::State, _context: Self::Context) -> Result { + /// Ok(Nothing) + /// } + /// + /// fn embedding_docs(&self) -> Vec { + /// vec!["Do nothing.".into()] + /// } + /// + /// fn context(&self) -> Self::Context {} + /// } + /// + /// let tool = EmbeddableTool::try_from(&Nothing).unwrap(); + /// + /// assert_eq!(tool.name, "nothing".to_string()); + /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); + /// ``` pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { Ok(EmbeddableTool { name: tool.name(), diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 6b337073..2ef24051 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -79,7 +79,7 @@ pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::Embeddable; -pub use one_or_many::OneOrMany; +pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; From c502ea5a623817126fc5043051264beaed54585c Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 23 Oct 2024 10:32:02 -0400 Subject: [PATCH 3/6] feat: rename Embeddable trait to ExtractEmbeddingFields --- rig-core/examples/rag.rs | 6 +- rig-core/examples/vector_search.rs | 4 +- rig-core/examples/vector_search_cohere.rs | 4 +- rig-core/rig-core-derive/src/basic.rs | 4 +- rig-core/rig-core-derive/src/embeddable.rs | 16 +- rig-core/rig-core-derive/src/lib.rs | 2 +- rig-core/src/embeddings/builder.rs | 40 
++--- rig-core/src/embeddings/embeddable.rs | 163 ------------------ .../embeddings/extract_embedding_fields.rs | 163 ++++++++++++++++++ rig-core/src/embeddings/mod.rs | 4 +- rig-core/src/embeddings/tool.rs | 16 +- rig-core/src/lib.rs | 4 +- rig-core/src/providers/cohere.rs | 4 +- rig-core/src/providers/openai.rs | 4 +- rig-core/src/tool.rs | 4 +- rig-core/tests/embeddable_macro.rs | 30 ++-- rig-lancedb/examples/fixtures/lib.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 4 +- 18 files changed, 238 insertions(+), 238 deletions(-) delete mode 100644 rig-core/src/embeddings/embeddable.rs create mode 100644 rig-core/src/embeddings/extract_embedding_fields.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d03902d1..ab2f7767 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -5,14 +5,14 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - Embeddable, + ExtractEmbeddingFields, }; use serde::Serialize; // Shape of data that needs to be RAG'ed. -// A vector search needs to be performed on the definitions, so we derive the `Embeddable` trait for `FakeDefinition` +// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition` // and tag that field with `#[embed]`. -#[derive(Embeddable, Serialize, Clone, Debug, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index b40f271c..36bb8d7e 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index da1e474b..003d39f5 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
-#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 86bb13ad..39b72018 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator syn::Resu let (custom_targets, custom_target_size) = data_struct.custom()?; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. + // ie. do not implement `ExtractEmbeddingFields` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( name, @@ -34,7 +34,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu _ => { return Err(syn::Error::new_spanned( input, - "Embeddable derive macro should only be used on structs", + "ExtractEmbeddingFields derive macro should only be used on structs", )) } }; @@ -42,18 +42,18 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: Embeddable trait is imported with the macro. + // Note: `ExtractEmbeddingFields` trait is imported with the macro. - impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Error = rig::embeddings::embeddable::EmbeddableError; + impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { + type Error = rig::embeddings::embeddable::ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { + fn extract_embedding_fields(&self) -> Result, Self::Error> { #target_stream; rig::OneOrMany::merge( embed_targets.into_iter() .collect::, _>>()? - ).map_err(rig::embeddings::embeddable::EmbeddableError::new) + ).map_err(rig::embeddings::embeddable::ExtractEmbeddingFieldsError::new) } } }; @@ -87,7 +87,7 @@ impl StructParser for DataStruct { if !embed_targets.is_empty() { ( quote! { - vec![#(#embed_targets.embeddable()),*] + vec![#(#embed_targets.extract_embedding_fields()),*] }, embed_targets.len(), ) diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index d28a0d78..042f7ca9 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -11,7 +11,7 @@ pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embeddable, attributes(embed))] +#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index fed2ef06..b6138ef5 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,22 +1,22 @@ //! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! 
Only types that implement the [ExtractEmbeddingFields] trait can be added to the [EmbeddingsBuilder]. use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use crate::{ - embeddings::{Embeddable, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{ExtractEmbeddingFields, Embedding, EmbeddingError, EmbeddingModel}, OneOrMany, }; /// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { model: M, documents: Vec<(T, OneOrMany)>, } -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,18 +25,18 @@ impl EmbeddingsBuilder { } } - /// Add a document that implements `Embeddable` to the builder. + /// Add a document that implements `ExtractEmbeddingFields` to the builder. pub fn document(mut self, document: T) -> Result { - let embed_targets = document.embeddable()?; + let embed_targets = document.extract_embedding_fields()?; self.documents.push((document, embed_targets)); Ok(self) } - /// Add many documents that implement `Embeddable` to the builder. + /// Add many documents that implement `ExtractEmbeddingFields` to the builder. pub fn documents(mut self, documents: Vec) -> Result { for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; + let embed_targets = doc.extract_embedding_fields()?; self.documents.push((doc, embed_targets)); } @@ -53,13 +53,13 @@ impl EmbeddingsBuilder { /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// Embeddable, +/// ExtractEmbeddingFields, /// }; /// use serde::{Deserialize, Serialize}; /// /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. -/// #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] /// struct FakeDefinition { /// id: String, /// word: String, @@ -111,7 +111,7 @@ impl EmbeddingsBuilder { /// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); /// }) /// ``` -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). 
pub async fn build(self) -> Result)>, EmbeddingError> { @@ -179,8 +179,8 @@ impl EmbeddingsBuilder, } - impl Embeddable for FakeDefinition { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinition { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result, Self::Error> { + crate::OneOrMany::many(self.definitions.clone()).map_err(ExtractEmbeddingFieldsError::new) } } @@ -261,10 +261,10 @@ mod tests { definition: String, } - impl Embeddable for FakeDefinitionSingle { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinitionSingle { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { + fn extract_embedding_fields(&self) -> Result, Self::Error> { Ok(crate::OneOrMany::one(self.definition.clone())) } } diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs deleted file mode 100644 index 1d70ffcf..00000000 --- a/rig-core/src/embeddings/embeddable.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. - -use crate::one_or_many::OneOrMany; - -/// Error type used for when the `embeddable` method fails. -/// Used by default implementations of `Embeddable` for common types. -#[derive(Debug, thiserror::Error)] -#[error("{0}")] -pub struct EmbeddableError(#[from] Box); - -impl EmbeddableError { - pub fn new(error: E) -> Self { - EmbeddableError(Box::new(error)) - } -} - -/// Trait for types that can be embedded. -/// The `embeddable` method returns a `OneOrMany` which contains strings for which embeddings will be generated by the embeddings builder. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. -/// # Example -/// ```rust -/// use std::env; -/// -/// use serde::{Deserialize, Serialize}; -/// use rig::{OneOrMany, EmptyListError, Embeddable}; -/// -/// struct FakeDefinition { -/// id: String, -/// word: String, -/// definitions: String, -/// } -/// -/// let fake_definition = FakeDefinition { -/// id: "doc1".to_string(), -/// word: "rock".to_string(), -/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() -/// }; -/// -/// impl Embeddable for FakeDefinition { -/// type Error = EmptyListError; -/// -/// fn embeddable(&self) -> Result, Self::Error> { -/// // Embeddings only need to be generated for `definition` field. -/// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. 
-/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); -/// -/// OneOrMany::many(definitions) -/// } -/// } -/// ``` -pub trait Embeddable { - type Error: std::error::Error + Sync + Send + 'static; - - fn embeddable(&self) -> Result, Self::Error>; -} - -// ================================================================ -// Implementations of Embeddable for common types -// ================================================================ -impl Embeddable for String { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.clone())) - } -} - -impl Embeddable for i8 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i16 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i128 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for bool { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for char { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for serde_json::Value { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::one( - serde_json::to_string(self).map_err(EmbeddableError::new)?, - )) - } -} - -impl Embeddable for Vec { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result, Self::Error> { - let items = self - .iter() - .map(|item| item.embeddable()) - .collect::, _>>() - .map_err(EmbeddableError::new)?; - - OneOrMany::merge(items).map_err(EmbeddableError::new) - } -} diff --git a/rig-core/src/embeddings/extract_embedding_fields.rs b/rig-core/src/embeddings/extract_embedding_fields.rs new file mode 100644 index 00000000..e62d43c6 --- /dev/null +++ b/rig-core/src/embeddings/extract_embedding_fields.rs @@ -0,0 +1,163 @@ +//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. + +use crate::one_or_many::OneOrMany; + +/// Error type used for when the `extract_embedding_fields` method fails. +/// Used by default implementations of `ExtractEmbeddingFields` for common types. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct ExtractEmbeddingFieldsError(#[from] Box); + +impl ExtractEmbeddingFieldsError { + pub fn new(error: E) -> Self { + ExtractEmbeddingFieldsError(Box::new(error)) + } +} + +/// Derive this trait for structs whose fields need to be converted to vector embeddings. +/// The `extract_embedding_fields` method returns a `OneOrMany`. 
This function extracts the fields that need to be embedded and returns them as a list of strings. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl ExtractEmbeddingFields for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn extract_embedding_fields(&self) -> Result, Self::Error> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. +/// let definitions = self.definitions.split(",").collect::>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` +pub trait ExtractEmbeddingFields { + type Error: std::error::Error + Sync + Send + 'static; + + fn extract_embedding_fields(&self) -> Result, Self::Error>; +} + +// ================================================================ +// Implementations of ExtractEmbeddingFields for common types +// ================================================================ +impl ExtractEmbeddingFields for String { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.clone())) + } +} + +impl ExtractEmbeddingFields for i8 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i16 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i128 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for bool { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for char { + type Error = ExtractEmbeddingFieldsError; + + fn 
extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for serde_json::Value { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + Ok(OneOrMany::one( + serde_json::to_string(self).map_err(ExtractEmbeddingFieldsError::new)?, + )) + } +} + +impl ExtractEmbeddingFields for Vec { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result, Self::Error> { + let items = self + .iter() + .map(|item| item.extract_embedding_fields()) + .collect::, _>>() + .map_err(ExtractEmbeddingFieldsError::new)?; + + OneOrMany::merge(items).map_err(ExtractEmbeddingFieldsError::new) + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index b8ad9b62..f5e9ede5 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. pub mod builder; -pub mod embeddable; +pub mod extract_embedding_fields; pub mod embedding; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embeddable::Embeddable; +pub use extract_embedding_fields::ExtractEmbeddingFields; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index c7c23b87..7550038f 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,7 +1,7 @@ -use crate::{tool::ToolEmbeddingDyn, Embeddable, OneOrMany}; +use crate::{tool::ToolEmbeddingDyn, ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -use super::embeddable::EmbeddableError; +use super::extract_embedding_fields::ExtractEmbeddingFieldsError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. 
#[derive(Clone, Serialize, Default, Eq, PartialEq)] @@ -11,11 +11,11 @@ pub struct EmbeddableTool { pub embedding_docs: Vec, } -impl Embeddable for EmbeddableTool { - type Error = EmbeddableError; +impl ExtractEmbeddingFields for EmbeddableTool { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result, Self::Error> { - OneOrMany::many(self.embedding_docs.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result, Self::Error> { + OneOrMany::many(self.embedding_docs.clone()).map_err(ExtractEmbeddingFieldsError::new) } } @@ -81,10 +81,10 @@ impl EmbeddableTool { /// assert_eq!(tool.name, "nothing".to_string()); /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); /// ``` - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { Ok(EmbeddableTool { name: tool.name(), - context: tool.context().map_err(EmbeddableError::new)?, + context: tool.context().map_err(ExtractEmbeddingFieldsError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 2ef24051..5383b34e 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,8 +78,8 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::Embeddable; +pub use embeddings::ExtractEmbeddingFields; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] -pub use rig_derive::Embeddable; +pub use rig_derive::ExtractEmbeddingFields; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 8f8eefd4..a6d8f00b 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; @@ -85,7 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings( + pub fn embeddings( &self, model: &str, input_type: &str, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index b20df22f..0bfeac3c 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,7 +121,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e92896b8..528faba5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, + embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, tool::EmbeddableTool}, }; #[derive(Debug, thiserror::Error)] @@ -330,7 +330,7 @@ impl ToolSet { /// Convert tools in self to objects of type EmbeddableTool. 
/// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn embedabble_tools(&self) -> Result, EmbeddableError> { + pub fn embedabble_tools(&self) -> Result, ExtractEmbeddingFieldsError> { self.tools .values() .filter_map(|tool_type| { diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index cbc76c80..5f8891ff 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -1,14 +1,14 @@ -use rig::embeddings::embeddable::EmbeddableError; -use rig::{Embeddable, OneOrMany}; +use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; +use rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -fn serialize(definition: Definition) -> Result, EmbeddableError> { +fn serialize(definition: Definition) -> Result, ExtractEmbeddingFieldsError> { Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(EmbeddableError::new)?, + serde_json::to_string(&definition).map_err(ExtractEmbeddingFieldsError::new)?, )) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition { id: String, word: String, @@ -41,7 +41,7 @@ fn test_custom_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one( "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -49,7 +49,7 @@ fn test_custom_embed() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition2 { id: String, #[embed] @@ -76,17 +76,17 @@ fn test_custom_and_basic_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap().first(), + fake_definition.extract_embedding_fields().unwrap().first(), "house".to_string() ); assert_eq!( - fake_definition.embeddable().unwrap().rest(), + fake_definition.extract_embedding_fields().unwrap().rest(), vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition3 { id: String, word: String, @@ -109,12 +109,12 @@ fn test_single_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one(definition) ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company { id: String, company: String, @@ -132,7 +132,7 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let result = company.embeddable().unwrap(); + let result = company.extract_embedding_fields().unwrap(); assert_eq!( result, @@ -153,7 +153,7 @@ fn test_multiple_embed_strings() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company2 { id: String, #[embed] @@ -173,7 +173,7 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); assert_eq!( - company.embeddable().unwrap(), + company.extract_embedding_fields().unwrap(), OneOrMany::many(vec![ "Google".to_string(), "25".to_string(), diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index d6e02a5a..1a9089a5 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, 
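`EmbeddableTool` exists so that a heterogeneous `ToolSet` can be embedded as one homogeneous document type. A rough sketch of that flow under stated assumptions: the `ToolSet` and the embedding model are built elsewhere, the concrete model type is assumed to be `rig::providers::openai::EmbeddingModel`, and errors are funneled through `anyhow`.

```
use rig::embeddings::{EmbeddableTool, EmbeddingsBuilder};
use rig::providers::openai::EmbeddingModel;
use rig::tool::ToolSet;

// Sketch only: `toolset` and `embedding_model` are assumed to be constructed elsewhere.
async fn embed_toolset(
    toolset: &ToolSet,
    embedding_model: EmbeddingModel,
) -> Result<(), anyhow::Error> {
    // One `EmbeddableTool` per tool, so all documents share a single type.
    let tool_documents: Vec<EmbeddableTool> = toolset.embedabble_tools()?;

    // Each tool's `embedding_docs` strings are what gets embedded.
    let embeddings = EmbeddingsBuilder::new(embedding_model)
        .documents(tool_documents)?
        .build()
        .await?;

    println!("embedded {} tools", embeddings.len());
    Ok(())
}
```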
RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::Embedding; -use rig::{Embeddable, OneOrMany}; +use rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Deserialize; -#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] pub struct FakeDefinition { pub id: String, #[embed] diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index b095c060..cc16d4bc 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,7 +3,7 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::Embeddable; +use rig::ExtractEmbeddingFields; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; @@ -11,7 +11,7 @@ use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] struct FakeDefinition { #[serde(rename = "_id")] id: String, From ebc6b81e0190adc25a6622fa8c620a9962eaf78e Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 23 Oct 2024 10:38:51 -0400 Subject: [PATCH 4/6] feat: rename macro files, cargo fmt --- rig-core/Cargo.toml | 2 +- .../src/{embeddable.rs => extract_embedding_fields.rs} | 4 ++-- rig-core/rig-core-derive/src/lib.rs | 4 ++-- rig-core/src/embeddings/builder.rs | 9 ++++++--- rig-core/src/embeddings/mod.rs | 4 ++-- rig-core/src/providers/openai.rs | 5 ++++- ...ddable_macro.rs => extract_embedding_fields_macro.rs} | 0 7 files changed, 17 insertions(+), 11 deletions(-) rename rig-core/rig-core-derive/src/{embeddable.rs => extract_embedding_fields.rs} (95%) rename rig-core/tests/{embeddable_macro.rs => extract_embedding_fields_macro.rs} (100%) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index ff2df2da..ed666fe4 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,7 +35,7 @@ tokio-test = "0.4.4" derive = ["dep:rig-derive"] [[test]] -name = "embeddable_macro" +name = "extract_embedding_fields_macro" required-features = ["derive"] [[example]] diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/extract_embedding_fields.rs similarity index 95% rename from rig-core/rig-core-derive/src/embeddable.rs rename to rig-core/rig-core-derive/src/extract_embedding_fields.rs index 27dac489..5c21f6b2 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/extract_embedding_fields.rs @@ -45,7 +45,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu // Note: `ExtractEmbeddingFields` trait is imported with the macro. impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { - type Error = rig::embeddings::embeddable::ExtractEmbeddingFieldsError; + type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; fn extract_embedding_fields(&self) -> Result, Self::Error> { #target_stream; @@ -53,7 +53,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu rig::OneOrMany::merge( embed_targets.into_iter() .collect::, _>>()? 
- ).map_err(rig::embeddings::embeddable::ExtractEmbeddingFieldsError::new) + ).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new) } } }; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 042f7ca9..8ad69a65 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -4,7 +4,7 @@ use syn::{parse_macro_input, DeriveInput}; mod basic; mod custom; -mod embeddable; +mod extract_embedding_fields; pub(crate) const EMBED: &str = "embed"; @@ -15,7 +15,7 @@ pub(crate) const EMBED: &str = "embed"; pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - embeddable::expand_derive_embedding(&mut input) + extract_embedding_fields::expand_derive_embedding(&mut input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index b6138ef5..e884dc69 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -6,7 +6,7 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use crate::{ - embeddings::{ExtractEmbeddingFields, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{Embedding, EmbeddingError, EmbeddingModel, ExtractEmbeddingFields}, OneOrMany, }; @@ -179,7 +179,9 @@ impl Embeddi #[cfg(test)] mod tests { use crate::{ - embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel}, + embeddings::{ + extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel, + }, ExtractEmbeddingFields, }; @@ -219,7 +221,8 @@ mod tests { type Error = ExtractEmbeddingFieldsError; fn extract_embedding_fields(&self) -> Result, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()).map_err(ExtractEmbeddingFieldsError::new) + crate::OneOrMany::many(self.definitions.clone()) + .map_err(ExtractEmbeddingFieldsError::new) } } diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index f5e9ede5..37323cf5 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. 
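The generated impl shown above collects one `OneOrMany<String>` per `#[embed]` target and then flattens them with `OneOrMany::merge`. That merge step can be exercised on its own; the sketch below assumes only that `many` and `merge` fail on empty input and that input order is preserved, which is what the derive tests in this patch rely on.

```
use rig::OneOrMany;

fn merge_embed_targets() {
    // Two extraction results, like those produced for two `#[embed]` fields.
    let first = OneOrMany::one("a definition".to_string());
    let second = OneOrMany::many(vec!["another".to_string(), "a third".to_string()])
        .expect("non-empty input");

    // Flatten them into a single OneOrMany<String>.
    let merged = OneOrMany::merge(vec![first, second]).expect("non-empty input");

    assert_eq!(merged.first(), "a definition".to_string());
    assert_eq!(
        merged.rest(),
        vec!["another".to_string(), "a third".to_string()]
    );
}
```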
pub mod builder; -pub mod extract_embedding_fields; pub mod embedding; +pub mod extract_embedding_fields; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use extract_embedding_fields::ExtractEmbeddingFields; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use extract_embedding_fields::ExtractEmbeddingFields; pub use tool::EmbeddableTool; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 0bfeac3c..789c5282 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/extract_embedding_fields_macro.rs similarity index 100% rename from rig-core/tests/embeddable_macro.rs rename to rig-core/tests/extract_embedding_fields_macro.rs From 5c2d451c0e536670a28bd7399b5ff323f509620c Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 24 Oct 2024 09:27:12 -0400 Subject: [PATCH 5/6] PR; update docstrings, update `add_documents_with_id` function --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 12 +----- rig-core/src/vector_store/in_memory_store.rs | 20 ++------- rig-lancedb/src/lib.rs | 34 +++++++-------- rig-mongodb/src/lib.rs | 45 ++++++++++---------- 9 files changed, 51 insertions(+), 70 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 723bfada..576491d5 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -252,7 +252,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "name")? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ab2f7767..9829089f 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -56,7 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index c3a2c251..c140da15 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -161,7 +161,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "name")? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? 
.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 36bb8d7e..4777b42c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -57,7 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(model); let results = index diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 003d39f5..6d966004 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -58,7 +58,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(search_model); let results = index diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index e884dc69..b1b310bb 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -73,7 +73,6 @@ impl EmbeddingsBuilder { /// /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// -/// # tokio_test::block_on(async { /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ /// FakeDefinition { @@ -100,16 +99,9 @@ impl EmbeddingsBuilder { /// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() /// ] /// }, -/// ]) -/// .unwrap() +/// ])? /// .build() -/// .await -/// .unwrap(); -/// -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc0" && embeddings.len() == 2), true); -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc1" && embeddings.len() == 2), true); -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); -/// }) +/// .await?; /// ``` impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 5ab9d6e7..f4f067fe 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,28 +76,16 @@ impl InMemoryVectorStore { Ok(self) } - /// Add documents to the store. Define the name of the field in the document that contains the id. + /// Add documents to the store. Define a function that takes as input the reference of the document and returns its id. /// Returns the store with the added documents. pub fn add_documents_with_id( mut self, documents: Vec<(D, OneOrMany)>, - id_field: &str, + id_f: fn(&D) -> String, ) -> Result { for (doc, embeddings) in documents { - if let serde_json::Value::Object(o) = - serde_json::to_value(&doc).map_err(VectorStoreError::JsonError)? 
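Taken together with the builder docstring above, the new id selector makes the in-memory flow look roughly like the sketch below. It is not code from this patch: the OpenAI client setup is elided, `WordDefinition` is a hypothetical document type, and the `in_memory_store` import path simply follows the module layout shown here.

```
use rig::embeddings::EmbeddingsBuilder;
use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002};
use rig::vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex};
use rig::ExtractEmbeddingFields;
use serde::{Deserialize, Serialize};

// Hypothetical document type; only `definition` is embedded.
#[derive(ExtractEmbeddingFields, Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)]
struct WordDefinition {
    id: String,
    #[embed]
    definition: String,
}

async fn build_and_query(openai_client: &Client) -> Result<(), anyhow::Error> {
    let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

    let embeddings = EmbeddingsBuilder::new(model.clone())
        .documents(vec![WordDefinition {
            id: "doc0".to_string(),
            definition: "A made-up word meaning 'to dawdle cheerfully'.".to_string(),
        }])?
        .build()
        .await?;

    // The second argument now picks the id off each document
    // instead of naming a serialized field.
    let index = InMemoryVectorStore::default()
        .add_documents_with_id(embeddings, |doc| doc.id.clone())?
        .index(model);

    let results = index
        .top_n::<WordDefinition>("a word for dawdling cheerfully", 1)
        .await?;
    println!("{:?}", results);

    Ok(())
}
```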
- { - match o.get(id_field) { - Some(serde_json::Value::String(s)) => { - self.embeddings.insert(s.clone(), (doc, embeddings)); - } - _ => { - return Err(VectorStoreError::MissingIdError(format!( - "Document does not have a field {id_field}" - ))); - } - } - }; + let id = id_f(&doc); + self.embeddings.insert(id, (doc, embeddings)); } Ok(self) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index eaaffbe3..46a6ef83 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -26,9 +26,9 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// -/// fn create_index(table: lancedb::Table, model: EmbeddingModel) { -/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; -/// } +/// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. +/// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. +/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` pub struct LanceDbVectorIndex { /// Defines which model is used to generate embeddings for the vector store. @@ -181,14 +181,14 @@ impl VectorStoreIndex for LanceDbVectorIndex /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// - /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { - /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// - /// // Query the index - /// let result = vector_store_index - /// .top_n::("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// let result = vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n Deserialize<'a> + Send>( &self, @@ -227,14 +227,14 @@ impl VectorStoreIndex for LanceDbVectorIndex /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. /// # Example /// ``` - /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { - /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. 
+ /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// - /// // Query the index - /// let result = vector_store_index - /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n_ids( &self, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index a803d385..50f67b11 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -185,18 +185,19 @@ impl VectorStoreIndex /// definition: String, /// } /// - /// fn execute_search(collection: mongodb::Collection, model: EmbeddingModel) { - /// let vector_store_index = MongoDbVectorStore::new(collection).index( - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); + /// let collection: collection: mongodb::Collection = \* ... \*; // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. /// - /// // Query the index - /// vector_store_index - /// .top_n::("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n Deserialize<'a> + Send>( &self, @@ -252,18 +253,18 @@ impl VectorStoreIndex /// embedding: Vec, /// } /// - /// fn execute_search(collection: mongodb::Collection, model: EmbeddingModel) { - /// let vector_store_index = MongoDbVectorStore::new(collection).index( - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); + /// let collection: collection: mongodb::Collection = \* ... \*; // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
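Read as one piece, the LanceDB docstrings above amount to the flow sketched below. This is illustrative only: the table and embedding model are taken as parameters because their construction is elided in the docstrings, the concrete model type is assumed to be `rig::providers::openai::EmbeddingModel`, and `Definition` is a hypothetical deserialization target.

```
use rig::providers::openai::EmbeddingModel;
use rig::vector_store::VectorStoreIndex;
use rig_lancedb::{LanceDbVectorIndex, SearchParams};
use serde::Deserialize;

// Hypothetical shape of the stored documents, used to deserialize results.
#[derive(Deserialize, Debug)]
struct Definition {
    id: String,
    definition: String,
}

async fn search(table: lancedb::Table, model: EmbeddingModel) -> Result<(), anyhow::Error> {
    // "id" is the column used as the document id; SearchParams tunes the query.
    let index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;

    // Full documents, deserialized into `Definition`.
    let results = index
        .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("{:?}", results);

    // Ids only, when the payload is not needed.
    let ids = index
        .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("{:?}", ids);

    Ok(())
}
```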
+ /// ); /// - /// // Query the index - /// vector_store_index - /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n_ids( &self, From 55b42d86b5222a380ba378192af7ad48908fd13d Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 24 Oct 2024 09:49:25 -0400 Subject: [PATCH 6/6] doc: fix doc linting --- rig-lancedb/src/lib.rs | 12 ++++++------ rig-mongodb/src/lib.rs | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 46a6ef83..2eea2357 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -26,8 +26,8 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// -/// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. -/// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. +/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` pub struct LanceDbVectorIndex { @@ -181,8 +181,8 @@ impl VectorStoreIndex for LanceDbVectorIndex /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// - /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. - /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// /// // Query the index @@ -227,8 +227,8 @@ impl VectorStoreIndex for LanceDbVectorIndex /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. /// # Example /// ``` - /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. - /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. 
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// /// // Query the index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 50f67b11..655f3939 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -24,13 +24,13 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// embedding: Vec, /// } /// -/// fn create_index(collection: mongodb::Collection, model: EmbeddingModel) { -/// let index = MongoDbVectorStore::new(collection).index( -/// model, -/// "vector_index", // <-- replace with the name of the index in your mongodb collection. -/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. -/// ); -/// } +/// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. +/// let index = MongoDbVectorStore::new(collection).index( +/// model, +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. +/// ); /// ``` pub struct MongoDbVectorStore { collection: mongodb::Collection, @@ -185,8 +185,8 @@ impl VectorStoreIndex /// definition: String, /// } /// - /// let collection: collection: mongodb::Collection = \* ... \*; // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// /// let vector_store_index = MongoDbVectorStore::new(collection).index( /// model, @@ -253,8 +253,8 @@ impl VectorStoreIndex /// embedding: Vec, /// } /// - /// let collection: collection: mongodb::Collection = \* ... \*; // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// let vector_store_index = MongoDbVectorStore::new(collection).index( /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection.
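The MongoDB docstrings describe the same pattern end to end. Condensed, and under stated assumptions (an Atlas vector search index named `vector_index` over an `embedding` field, a collection handle and embedding model built elsewhere, a collection element type of `bson::Document`, and a hypothetical `Definition` result shape), the flow is roughly:

```
use mongodb::bson;
use rig::providers::openai::EmbeddingModel;
use rig::vector_store::VectorStoreIndex;
use rig_mongodb::{MongoDbVectorStore, SearchParams};
use serde::Deserialize;

// Hypothetical shape of the search results.
#[derive(Deserialize, Debug)]
struct Definition {
    #[serde(rename = "_id")]
    id: String,
    definition: String,
}

async fn search(
    collection: mongodb::Collection<bson::Document>,
    model: EmbeddingModel,
) -> Result<(), anyhow::Error> {
    let index = MongoDbVectorStore::new(collection).index(
        model,
        "vector_index",                 // name of the Atlas vector search index
        SearchParams::new("embedding"), // document field that stores the embeddings
    );

    let results = index
        .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("{:?}", results);

    Ok(())
}
```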