Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(embeddings): embed trait definition #89

Merged
merged 9 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ tokio-test = "0.4.4"
derive = ["dep:rig-derive"]

[[test]]
name = "extract_embedding_fields_macro"
name = "embed_macro"
required-features = ["derive"]

[[example]]
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(toolset.embedabble_tools()?)?
.documents(toolset.schemas()?)?
.build()
.await?;

Expand Down
8 changes: 4 additions & 4 deletions rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
ExtractEmbeddingFields,
Embed,
};
use serde::Serialize;

// Shape of data that needs to be RAG'ed.
// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition`
// Data to be RAGged.
// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition`
// and tag that field with `#[embed]`.
#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)]
#[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
#[embed]
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> {
.build();

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(toolset.embedabble_tools()?)?
.documents(toolset.schemas()?)?
.build()
.await?;

Expand Down
4 changes: 2 additions & 2 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
ExtractEmbeddingFields,
Embed,
};
use serde::{Deserialize, Serialize};

// Shape of data that needs to be RAG'ed.
// The definition field will be used to generate embeddings.
#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
word: String,
Expand Down
4 changes: 2 additions & 2 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
ExtractEmbeddingFields,
Embed,
};
use serde::{Deserialize, Serialize};

// Shape of data that needs to be RAG'ed.
// The definition field will be used to generate embeddings.
#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
word: String,
Expand Down
6 changes: 3 additions & 3 deletions rig-core/rig-core-derive/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use syn::{parse_quote, Attribute, DataStruct, Meta};

use crate::EMBED;

/// Finds and returns fields with simple #[embed] attribute tags only.
/// Finds and returns fields with simple `#[embed]` attribute tags only.
pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = &syn::Field> {
data_struct.fields.iter().filter(|field| {
field.attrs.iter().any(|attribute| match attribute {
Expand All @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item
})
}

/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the ExtractEmbeddingFields trait.
/// Adds bounds to where clause that force all fields tagged with `#[embed]` to implement the `Embed` trait.
pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) {
let where_clause = generics.make_where_clause();

where_clause.predicates.push(parse_quote! {
#field_type: ExtractEmbeddingFields
#field_type: Embed
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
let (basic_targets, basic_target_size) = data_struct.basic(generics);
let (custom_targets, custom_target_size) = data_struct.custom()?;

// If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream.
// ie. do not implement `ExtractEmbeddingFields` trait for the struct.
// If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an empty TokenStream.
// ie. do not implement `Embed` trait for the struct.
if basic_target_size + custom_target_size == 0 {
return Err(syn::Error::new_spanned(
name,
Expand All @@ -27,33 +27,28 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
}

quote! {
let mut embed_targets = #basic_targets;
embed_targets.extend(#custom_targets)
#basic_targets;
#custom_targets;
}
}
_ => {
return Err(syn::Error::new_spanned(
input,
"ExtractEmbeddingFields derive macro should only be used on structs",
"Embed derive macro should only be used on structs",
))
}
};

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let gen = quote! {
// Note: `ExtractEmbeddingFields` trait is imported with the macro.
// Note: `Embed` trait is imported with the macro.

impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause {
type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError;

fn extract_embedding_fields(&self) -> Result<rig::OneOrMany<String>, Self::Error> {
impl #impl_generics Embed for #name #ty_generics #where_clause {
fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> {
#target_stream;

rig::OneOrMany::merge(
embed_targets.into_iter()
.collect::<Result<Vec<_>, _>>()?
).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new)
Ok(())
}
}
};
Expand All @@ -62,17 +57,17 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
}

trait StructParser {
// Handles fields tagged with #[embed]
// Handles fields tagged with `#[embed]`
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize);

// Handles fields tagged with #[embed(embed_with = "...")]
// Handles fields tagged with `#[embed(embed_with = "...")]`
fn custom(&self) -> syn::Result<(TokenStream, usize)>;
}

impl StructParser for DataStruct {
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) {
let embed_targets = basic_embed_fields(self)
// Iterate over every field tagged with #[embed]
// Iterate over every field tagged with `#[embed]`
.map(|field| {
add_struct_bounds(generics, &field.ty);

Expand All @@ -84,50 +79,32 @@ impl StructParser for DataStruct {
})
.collect::<Vec<_>>();

if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets.extract_embedding_fields()),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
}
(
quote! {
#(#embed_targets.embed(embedder)?;)*
},
embed_targets.len(),
)
}

fn custom(&self) -> syn::Result<(TokenStream, usize)> {
let embed_targets = custom_embed_fields(self)?
// Iterate over every field tagged with #[embed(embed_with = "...")]
// Iterate over every field tagged with `#[embed(embed_with = "...")]`
.into_iter()
.map(|(field, custom_func_path)| {
let field_name = &field.ident;

quote! {
#custom_func_path(self.#field_name.clone())
#custom_func_path(embedder, self.#field_name.clone())?;
}
})
.collect::<Vec<_>>();

Ok(if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
})
Ok((
quote! {
#(#embed_targets)*
},
embed_targets.len(),
))
}
}
12 changes: 6 additions & 6 deletions rig-core/rig-core-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ use syn::{parse_macro_input, DeriveInput};

mod basic;
mod custom;
mod extract_embedding_fields;
mod embed;

pub(crate) const EMBED: &str = "embed";

// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro
// https://doc.rust-lang.org/reference/procedural-macros.html

#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))]
/// References:
/// <https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
/// <https://doc.rust-lang.org/reference/procedural-macros.html>
#[proc_macro_derive(Embed, attributes(embed))]
pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(item as DeriveInput);

extract_embedding_fields::expand_derive_embedding(&mut input)
embed::expand_derive_embedding(&mut input)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
2 changes: 1 addition & 1 deletion rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub enum CompletionError {

/// Error building the completion request
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync>),
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

/// Error parsing the completion response
#[error("ResponseError: {0}")]
Expand Down
Loading
Loading