From 1b30e230ff4cad930259aa43a68fefb4e0fda566 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Tue, 17 Dec 2024 12:19:31 -0800 Subject: [PATCH 1/5] partial impl --- rig-core/src/completion.rs | 65 ++++++++++++++++++++++++++++++++++--- rig-core/src/one_or_many.rs | 4 ++- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index f13f316b..886004db 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -67,7 +67,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{json_utils, tool::ToolSetError}; +use crate::{json_utils, tool::ToolSetError, OneOrMany}; // Errors #[derive(Debug, Error)] @@ -106,10 +106,65 @@ pub enum PromptError { // Request models // ================================================================ #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Message { - /// "system", "user", or "assistant" - pub role: String, - pub content: String, +#[serde(tag = "role", rename_all = "lowercase")] +pub enum Message { + System { + content: OneOrMany, + raw_message: T, + }, + User { + content: OneOrMany, + raw_message: T, + }, + Assistant { + content: OneOrMany, + raw_message: T, + }, + Tool { + content: String, + tool_call_id: String, + raw_message: T, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum UserContent { + #[serde(rename = "text")] + Text { + text: String, + }, + #[serde(rename = "image_url")] + Image { + image_url: String, + detail: ImageDetail, + }, + #[serde(rename = "input_audio")] + Audio { + data: String, + format: String, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ImageDetail { + Low, + High, + Auto, +} + +impl std::str::FromStr for ImageDetail { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "low" => Ok(ImageDetail::Low), + "high" => Ok(ImageDetail::High), + "auto" => Ok(ImageDetail::Auto), + _ => Err(()), + } + } } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 64584603..e2431f75 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -1,9 +1,11 @@ +use serde::{Deserialize, Serialize}; + /// Struct containing either a single item or a list of items of type T. /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. /// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] pub struct OneOrMany { /// First item in the list. first: T, From 2454659ed085c77f3f9fbc962beeddcd8e53666d Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Fri, 27 Dec 2024 16:19:06 -0600 Subject: [PATCH 2/5] partial impl 2 --- rig-core/src/completion.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 886004db..c49df034 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -117,33 +117,44 @@ pub enum Message { raw_message: T, }, Assistant { + refusal: Option, content: OneOrMany, + tool_calls: OneOrMany, raw_message: T, }, Tool { + id: String, content: String, - tool_call_id: String, raw_message: T, }, } +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, + pub function: ToolFunction, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunction { + pub name: String, + pub arguments: String, +} + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(tag = "type")] pub enum UserContent { #[serde(rename = "text")] - Text { - text: String, - }, + Text { text: String }, #[serde(rename = "image_url")] Image { image_url: String, detail: ImageDetail, }, #[serde(rename = "input_audio")] - Audio { - data: String, - format: String, - }, + Audio { data: String, format: String }, } #[derive(Clone, Debug, Deserialize, Serialize)] From 7aa449a3c30fb6839cf75d60b4f27c7972188824 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Thu, 9 Jan 2025 23:02:05 -0800 Subject: [PATCH 3/5] partial impl 3 --- rig-core/src/completion.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index c49df034..1f7ffc99 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -107,25 +107,21 @@ pub enum PromptError { // ================================================================ #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(tag = "role", rename_all = "lowercase")] -pub enum Message { +pub enum Message { System { content: OneOrMany, - raw_message: T, }, User { content: OneOrMany, - raw_message: T, }, Assistant { refusal: Option, content: OneOrMany, tool_calls: OneOrMany, - raw_message: T, }, Tool { id: String, content: String, - raw_message: T, }, } From 63b631480dabd202c071446927f783108a9c83df Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Mon, 13 Jan 2025 14:03:30 -0800 Subject: [PATCH 4/5] partial impl 4 --- Cargo.lock | 2 +- rig-core/src/agent.rs | 160 ++++++++++++++++++++---------------- rig-core/src/cli_chatbot.rs | 5 +- rig-core/src/completion.rs | 90 ++------------------ rig-core/src/lib.rs | 1 + rig-core/src/message.rs | 128 +++++++++++++++++++++++++++++ 6 files changed, 231 insertions(+), 155 deletions(-) create mode 100644 rig-core/src/message.rs diff --git a/Cargo.lock b/Cargo.lock index 62e304f9..66e040b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 2cbe5fbd..2a571524 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -113,8 +113,9 @@ use futures::{stream, StreamExt, TryStreamExt}; use crate::{ completion::{ Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, - CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError, + CompletionResponse, Document, ModelChoice, Prompt, PromptError, }, + message::Message, tool::{Tool, ToolSet}, vector_store::{VectorStoreError, VectorStoreIndexDyn}, }; @@ -165,94 +166,109 @@ pub struct Agent { impl Completion for Agent { async fn completion( &self, - prompt: &str, + prompt: impl Into, chat_history: Vec, ) -> Result, CompletionError> { - let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, id, doc)| { - // Pretty print the document if possible for better readability - let text = serde_json::to_string_pretty(&doc) - .unwrap_or_else(|_| doc.to_string()); - - Document { - id, - text, - additional_props: HashMap::new(), - } - }) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - acc.extend(docs); - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index)| async { - Ok::<_, VectorStoreError>( - index - .top_n_ids(prompt, *num_sample) - .await? - .into_iter() - .map(|(_, id)| id) - .collect::>(), - ) - }) - .try_fold(vec![], |mut acc, docs| async { - for doc in docs { - if let Some(tool) = self.tools.get(&doc) { - acc.push(tool.definition(prompt.into()).await) - } else { - tracing::warn!("Tool implementation not found in toolset: {}", doc); - } - } - Ok(acc) - }) - .await - .map_err(|e| CompletionError::RequestError(Box::new(e)))?; - - let static_tools = stream::iter(self.static_tools.iter()) - .filter_map(|toolname| async move { - if let Some(tool) = self.tools.get(toolname) { - Some(tool.definition(prompt.into()).await) - } else { - tracing::warn!("Tool implementation not found in toolset: {}", toolname); - None - } - }) - .collect::>() - .await; - - Ok(self + let completion_request = self .model .completion_request(prompt) .preamble(self.preamble.clone()) .messages(chat_history) - .documents([self.static_context.clone(), dynamic_context].concat()) - .tools([static_tools.clone(), dynamic_tools].concat()) .temperature_opt(self.temperature) .max_tokens_opt(self.max_tokens) - .additional_params_opt(self.additional_params.clone())) + .additional_params_opt(self.additional_params.clone()); + + let agent = match prompt.into().rag_text() { + Some(text) => { + let dynamic_context = stream::iter(self.dynamic_context.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n(&text, *num_sample) + .await? + .into_iter() + .map(|(_, id, doc)| { + // Pretty print the document if possible for better readability + let text = serde_json::to_string_pretty(&doc) + .unwrap_or_else(|_| doc.to_string()); + + Document { + id, + text, + additional_props: HashMap::new(), + } + }) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + acc.extend(docs); + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let dynamic_tools = stream::iter(self.dynamic_tools.iter()) + .then(|(num_sample, index)| async { + Ok::<_, VectorStoreError>( + index + .top_n_ids(&text, *num_sample) + .await? + .into_iter() + .map(|(_, id)| id) + .collect::>(), + ) + }) + .try_fold(vec![], |mut acc, docs| async { + for doc in docs { + if let Some(tool) = self.tools.get(&doc) { + acc.push(tool.definition(text.into()).await) + } else { + tracing::warn!("Tool implementation not found in toolset: {}", doc); + } + } + Ok(acc) + }) + .await + .map_err(|e| CompletionError::RequestError(Box::new(e)))?; + + let static_tools = stream::iter(self.static_tools.iter()) + .filter_map(|toolname| async move { + if let Some(tool) = self.tools.get(toolname) { + Some(tool.definition(text.into()).await) + } else { + tracing::warn!( + "Tool implementation not found in toolset: {}", + toolname + ); + None + } + }) + .collect::>() + .await; + + completion_request + .documents([self.static_context.clone(), dynamic_context].concat()) + .tools([static_tools.clone(), dynamic_tools].concat()) + } + None => completion_request, + }; + Ok(agent) } } impl Prompt for Agent { - async fn prompt(&self, prompt: &str) -> Result { + async fn prompt(&self, prompt: impl Into) -> Result { self.chat(prompt, vec![]).await } } impl Chat for Agent { - async fn chat(&self, prompt: &str, chat_history: Vec) -> Result { + async fn chat( + &self, + prompt: impl Into, + chat_history: Vec, + ) -> Result { match self.completion(prompt, chat_history).await?.send().await? { CompletionResponse { choice: ModelChoice::Message(msg), diff --git a/rig-core/src/cli_chatbot.rs b/rig-core/src/cli_chatbot.rs index df33fa8e..10914800 100644 --- a/rig-core/src/cli_chatbot.rs +++ b/rig-core/src/cli_chatbot.rs @@ -1,6 +1,9 @@ use std::io::{self, Write}; -use crate::completion::{Chat, Message, PromptError}; +use crate::{ + completion::{Chat, PromptError}, + message::Message, +}; /// Utility function to create a simple REPL CLI chatbot from a type that implements the /// `Chat` trait. diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 1f7ffc99..b85169ee 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -67,7 +67,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{json_utils, tool::ToolSetError, OneOrMany}; +use crate::{json_utils, message::Message, tool::ToolSetError}; // Errors #[derive(Debug, Error)] @@ -102,78 +102,6 @@ pub enum PromptError { ToolError(#[from] ToolSetError), } -// ================================================================ -// Request models -// ================================================================ -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(tag = "role", rename_all = "lowercase")] -pub enum Message { - System { - content: OneOrMany, - }, - User { - content: OneOrMany, - }, - Assistant { - refusal: Option, - content: OneOrMany, - tool_calls: OneOrMany, - }, - Tool { - id: String, - content: String, - }, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolCall { - pub id: String, - #[serde(rename = "type")] - pub tool_type: String, - pub function: ToolFunction, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunction { - pub name: String, - pub arguments: String, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(tag = "type")] -pub enum UserContent { - #[serde(rename = "text")] - Text { text: String }, - #[serde(rename = "image_url")] - Image { - image_url: String, - detail: ImageDetail, - }, - #[serde(rename = "input_audio")] - Audio { data: String, format: String }, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum ImageDetail { - Low, - High, - Auto, -} - -impl std::str::FromStr for ImageDetail { - type Err = (); - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "low" => Ok(ImageDetail::Low), - "high" => Ok(ImageDetail::High), - "auto" => Ok(ImageDetail::Auto), - _ => Err(()), - } - } -} - #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Document { pub id: String, @@ -226,7 +154,7 @@ pub trait Prompt: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn prompt( &self, - prompt: &str, + prompt: impl Into, ) -> impl std::future::Future> + Send; } @@ -242,7 +170,7 @@ pub trait Chat: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn chat( &self, - prompt: &str, + prompt: impl Into, chat_history: Vec, ) -> impl std::future::Future> + Send; } @@ -302,15 +230,15 @@ pub trait CompletionModel: Clone + Send + Sync { + Send; /// Generates a completion request builder for the given `prompt`. - fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder { - CompletionRequestBuilder::new(self.clone(), prompt.to_string()) + fn completion_request(&self, prompt: impl Into) -> CompletionRequestBuilder { + CompletionRequestBuilder::new(self.clone(), prompt) } } /// Struct representing a general completion request that can be sent to a completion model provider. pub struct CompletionRequest { /// The prompt to be sent to the completion model provider - pub prompt: String, + pub prompt: Message, /// The preamble to be sent to the completion model provider pub preamble: Option, /// The chat history to be sent to the completion model provider @@ -391,7 +319,7 @@ impl CompletionRequest { /// Instead, use the [CompletionModel::completion_request] method. pub struct CompletionRequestBuilder { model: M, - prompt: String, + prompt: Message, preamble: Option, chat_history: Vec, documents: Vec, @@ -402,10 +330,10 @@ pub struct CompletionRequestBuilder { } impl CompletionRequestBuilder { - pub fn new(model: M, prompt: String) -> Self { + pub fn new(model: M, prompt: impl Into) -> Self { Self { model, - prompt, + prompt: prompt.into(), preamble: None, chat_history: Vec::new(), documents: Vec::new(), diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index f1b5427b..8daa3c8c 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -85,6 +85,7 @@ pub mod embeddings; pub mod extractor; pub(crate) mod json_utils; pub mod loaders; +pub mod message; pub mod one_or_many; pub mod providers; pub mod tool; diff --git a/rig-core/src/message.rs b/rig-core/src/message.rs new file mode 100644 index 00000000..c1d056f2 --- /dev/null +++ b/rig-core/src/message.rs @@ -0,0 +1,128 @@ +use crate::OneOrMany; +use serde::{Deserialize, Serialize}; + +// ================================================================ +// Request models +// ================================================================ +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum Message { + System { + content: OneOrMany, + }, + User { + content: OneOrMany, + }, + Assistant { + content: OneOrMany, + tool_calls: OneOrMany, + }, + Tool { + id: String, + content: String, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, + pub function: ToolFunction, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolFunction { + pub name: String, + pub arguments: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum UserContent { + Text { + text: String, + }, + Image { + data: String, + format: ContentFormat, + detail: ImageDetail, + r#media_type: MediaType, + }, + Document { + data: String, + format: ContentFormat, + r#media_type: String, + }, + Audio { + data: String, + format: ContentFormat, + r#media_type: String, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ContentFormat { + Base64, + String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum MediaType { + #[serde(rename = "application/pdf")] + ApplicationPdf, + #[serde(rename = "image/jpeg")] + ImageJpeg, + #[serde(rename = "image/png")] + ImagePng, + #[serde(rename = "image/gif")] + ImageGif, + #[serde(rename = "image/webp")] + ImageWebp, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ImageDetail { + Low, + High, + Auto, +} + +impl std::str::FromStr for ImageDetail { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "low" => Ok(ImageDetail::Low), + "high" => Ok(ImageDetail::High), + "auto" => Ok(ImageDetail::Auto), + _ => Err(()), + } + } +} + +impl From for Message { + fn from(text: String) -> Self { + Message::User { + content: OneOrMany::::one(UserContent::Text { text }), + } + } +} + +impl Message { + pub fn rag_text(&self) -> Option { + match self { + Message::User { content } => { + if let UserContent::Text { text }= content.first() { + Some(text.clone()) + } else { + None + } + } + _ => None, + } + } +} From f24ce60ab4b5be3bc1ada238873671cbd3437dcf Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Wed, 15 Jan 2025 22:04:15 -0800 Subject: [PATCH 5/5] partial impl 5 --- rig-core/src/agent.rs | 11 +- rig-core/src/completion.rs | 69 +++-- rig-core/src/message.rs | 77 ++++-- rig-core/src/one_or_many.rs | 10 + .../src/providers/anthropic/completion.rs | 244 +++++++++++++++--- rig-core/src/providers/cohere.rs | 12 +- rig-core/src/providers/perplexity.rs | 94 +++++-- 7 files changed, 398 insertions(+), 119 deletions(-) diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 2a571524..83efee8c 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -166,9 +166,12 @@ pub struct Agent { impl Completion for Agent { async fn completion( &self, - prompt: impl Into, + prompt: impl Into + Send, chat_history: Vec, ) -> Result, CompletionError> { + let prompt = prompt.into(); + let rag_text = prompt.rag_text().clone(); + let completion_request = self .model .completion_request(prompt) @@ -178,7 +181,7 @@ impl Completion for Agent { .max_tokens_opt(self.max_tokens) .additional_params_opt(self.additional_params.clone()); - let agent = match prompt.into().rag_text() { + let agent = match &rag_text { Some(text) => { let dynamic_context = stream::iter(self.dynamic_context.iter()) .then(|(num_sample, index)| async { @@ -258,7 +261,7 @@ impl Completion for Agent { } impl Prompt for Agent { - async fn prompt(&self, prompt: impl Into) -> Result { + async fn prompt(&self, prompt: impl Into + Send) -> Result { self.chat(prompt, vec![]).await } } @@ -266,7 +269,7 @@ impl Prompt for Agent { impl Chat for Agent { async fn chat( &self, - prompt: impl Into, + prompt: impl Into + Send, chat_history: Vec, ) -> Result { match self.completion(prompt, chat_history).await?.send().await? { diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index b85169ee..1eb1db69 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -67,7 +67,11 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{json_utils, message::Message, tool::ToolSetError}; +use crate::{ + json_utils, + message::{Message, UserContent}, + tool::ToolSetError, +}; // Errors #[derive(Debug, Error)] @@ -154,7 +158,7 @@ pub trait Prompt: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn prompt( &self, - prompt: impl Into, + prompt: impl Into + Send, ) -> impl std::future::Future> + Send; } @@ -170,7 +174,7 @@ pub trait Chat: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn chat( &self, - prompt: impl Into, + prompt: impl Into + Send, chat_history: Vec, ) -> impl std::future::Future> + Send; } @@ -190,7 +194,7 @@ pub trait Completion { /// contain the `preamble` provided when creating the agent. fn completion( &self, - prompt: &str, + prompt: impl Into + Send, chat_history: Vec, ) -> impl std::future::Future, CompletionError>> + Send; } @@ -256,20 +260,26 @@ pub struct CompletionRequest { } impl CompletionRequest { - pub(crate) fn prompt_with_context(&self) -> String { - if !self.documents.is_empty() { - format!( - "\n{}\n\n{}", - self.documents + pub(crate) fn prompt_with_context(&self) -> Message { + let mut new_prompt = self.prompt.clone(); + if let Message::User { ref mut content } = new_prompt { + if !self.documents.is_empty() { + let attachments = self + .documents .iter() .map(|doc| doc.to_string()) .collect::>() - .join(""), - self.prompt - ) - } else { - self.prompt.clone() + .join(""); + let formatted_content = format!("\n{}", attachments); + content.insert( + 0, + UserContent::Text { + text: formatted_content, + }, + ); + } } + new_prompt } } @@ -465,6 +475,8 @@ impl CompletionRequestBuilder { #[cfg(test)] mod tests { + use crate::OneOrMany; + use super::*; #[test] @@ -515,7 +527,7 @@ mod tests { }; let request = CompletionRequest { - prompt: "What is the capital of France?".to_string(), + prompt: "What is the capital of France?".into(), preamble: None, chat_history: Vec::new(), documents: vec![doc1, doc2], @@ -525,14 +537,25 @@ mod tests { additional_params: None, }; - let expected = concat!( - "\n", - "\nDocument 1 text.\n\n", - "\nDocument 2 text.\n\n", - "\n\n", - "What is the capital of France?" - ) - .to_string(); + let expected = Message::User { + content: OneOrMany::many(vec![ + UserContent::Text { + text: concat!( + "\n", + "\nDocument 1 text.\n\n", + "\nDocument 2 text.\n\n", + "\n" + ) + .to_string(), + }, + UserContent::Text { + text: "What is the capital of France?".to_string(), + }, + ]) + .expect("This has more than 1 item"), + }; + + request.prompt_with_context(); assert_eq!(request.prompt_with_context(), expected); } diff --git a/rig-core/src/message.rs b/rig-core/src/message.rs index c1d056f2..02dd17b3 100644 --- a/rig-core/src/message.rs +++ b/rig-core/src/message.rs @@ -1,21 +1,19 @@ use crate::OneOrMany; use serde::{Deserialize, Serialize}; +use thiserror::Error; // ================================================================ // Request models // ================================================================ -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum Message { - System { - content: OneOrMany, - }, User { content: OneOrMany, }, Assistant { - content: OneOrMany, - tool_calls: OneOrMany, + content: Vec, + tool_calls: Vec, }, Tool { id: String, @@ -23,7 +21,7 @@ pub enum Message { }, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub struct ToolCall { pub id: String, #[serde(rename = "type")] @@ -31,13 +29,13 @@ pub struct ToolCall { pub function: ToolFunction, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub struct ToolFunction { pub name: String, - pub arguments: String, + pub arguments: serde_json::Value, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] pub enum UserContent { Text { @@ -47,43 +45,50 @@ pub enum UserContent { data: String, format: ContentFormat, detail: ImageDetail, - r#media_type: MediaType, + media_type: ImageMediaType, }, Document { data: String, format: ContentFormat, - r#media_type: String, + media_type: DocumentMediaType, }, Audio { data: String, format: ContentFormat, - r#media_type: String, + media_type: AudioMediaType, }, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ContentFormat { Base64, String, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageMediaType { + JPEG, + PNG, + GIF, + WEBP, +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum DocumentMediaType { + PDF, +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "lowercase")] -pub enum MediaType { - #[serde(rename = "application/pdf")] - ApplicationPdf, - #[serde(rename = "image/jpeg")] - ImageJpeg, - #[serde(rename = "image/png")] - ImagePng, - #[serde(rename = "image/gif")] - ImageGif, - #[serde(rename = "image/webp")] - ImageWebp, +pub enum AudioMediaType { + WAV, + MP4, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ImageDetail { Low, @@ -112,11 +117,21 @@ impl From for Message { } } +impl From<&str> for Message { + fn from(text: &str) -> Self { + Message::User { + content: OneOrMany::::one(UserContent::Text { + text: text.to_owned(), + }), + } + } +} + impl Message { pub fn rag_text(&self) -> Option { match self { Message::User { content } => { - if let UserContent::Text { text }= content.first() { + if let UserContent::Text { text } = content.first() { Some(text.clone()) } else { None @@ -126,3 +141,9 @@ impl Message { } } } + +#[derive(Debug, Error)] +pub enum MessageError { + #[error("Message conversion error: {0}")] + ConversionError(String), +} diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index e2431f75..211c829d 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -34,6 +34,16 @@ impl OneOrMany { self.rest.push(item); } + /// After `OneOrMany` is created, insert an item of type T at an index. + pub fn insert(&mut self, index: usize, item: T) { + if index == 0 { + let old_first = std::mem::replace(&mut self.first, item); + self.rest.insert(0, old_first); + } else { + self.rest.insert(index - 1, item); + } + } + /// Length of all items in `OneOrMany`. pub fn len(&self) -> usize { 1 + self.rest.len() diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index cc9c84dc..1cb8b162 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -1,10 +1,9 @@ //! Anthropic completion api implementation -use std::iter; - use crate::{ completion::{self, CompletionError}, json_utils, + message::{self, MessageError}, }; use serde::{Deserialize, Serialize}; @@ -42,22 +41,6 @@ pub struct CompletionResponse { pub usage: Usage, } -#[derive(Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum Content { - String(String), - Text { - r#type: String, - text: String, - }, - ToolUse { - r#type: String, - id: String, - name: String, - input: serde_json::Value, - }, -} - #[derive(Debug, Deserialize, Serialize)] pub struct Usage { pub input_tokens: u64, @@ -103,12 +86,10 @@ impl TryFrom for completion::CompletionResponse std::prelude::v1::Result { match response.content.as_slice() { - [Content::String(text) | Content::Text { text, .. }, ..] => { - Ok(completion::CompletionResponse { - choice: completion::ModelChoice::Message(text.to_string()), - raw_response: response, - }) - } + [Content::Text { text, .. }, ..] => Ok(completion::CompletionResponse { + choice: completion::ModelChoice::Message(text.to_string()), + raw_response: response, + }), [Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse { choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()), raw_response: response, @@ -123,18 +104,189 @@ impl TryFrom for completion::CompletionResponse for Message { - fn from(message: completion::Message) -> Self { - Self { - role: message.role, - content: message.content, +#[derive(Debug, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Content { + Text { + text: String, + }, + Image { + source: ImageSource, + }, + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + tool_use_id: String, + content: ToolResultContent, + is_error: bool, + }, + Document { + source: DocumentSource, + }, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolResultContent { + Text { text: String }, + Image(ImageSource), +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ImageSource { + pub data: String, + pub format: ImageFormat, + pub r#type: SourceType, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct DocumentSource { + pub data: String, + pub format: DocumentFormat, + pub r#type: SourceType, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ImageFormat { + #[serde(rename = "image/jpeg")] + JPEG, + #[serde(rename = "image/png")] + PNG, + #[serde(rename = "image/gif")] + GIF, + #[serde(rename = "image/webp")] + WEBP, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum DocumentFormat { + #[serde(rename = "application/pdf")] + PDF, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum SourceType { + BASE64, +} + +impl From for Content { + fn from(text: String) -> Self { + Content::Text { text } + } +} + +impl From for ToolResultContent { + fn from(text: String) -> Self { + ToolResultContent::Text { text } + } +} + +impl TryFrom for SourceType { + type Error = MessageError; + + fn try_from(format: message::ContentFormat) -> Result { + match format { + message::ContentFormat::Base64 => Ok(SourceType::BASE64), + message::ContentFormat::String => Err(MessageError::ConversionError( + "Image urls are not supported in Anthropic".to_owned(), + )), } } } +impl From for ImageFormat { + fn from(media_type: message::ImageMediaType) -> Self { + match media_type { + message::ImageMediaType::JPEG => ImageFormat::JPEG, + message::ImageMediaType::PNG => ImageFormat::PNG, + message::ImageMediaType::GIF => ImageFormat::GIF, + message::ImageMediaType::WEBP => ImageFormat::WEBP, + } + } +} + +impl TryFrom for Vec { + type Error = MessageError; + + fn try_from(message: message::Message) -> Result { + Ok(match message { + message::Message::User { content } => content + .into_iter() + .map(|content| match content { + message::UserContent::Text { text } => Ok(Content::Text { text }), + message::UserContent::Image { + data, + format, + media_type, + .. + } => { + let source = ImageSource { + data, + format: media_type.into(), + r#type: format.try_into()?, + }; + Ok(Content::Image { source }) + } + message::UserContent::Document { data, format, .. } => { + let source = DocumentSource { + data, + format: DocumentFormat::PDF, + r#type: format.try_into()?, + }; + Ok(Content::Document { source }) + } + message::UserContent::Audio { .. } => Err(MessageError::ConversionError( + "Audio is not supported in Anthropic".to_owned(), + )), + }) + .collect::, _>>()? + .into_iter() + .map(|content| Message { + role: "user".to_owned(), + content, + }) + .collect::>(), + + message::Message::Assistant { + content, + tool_calls, + } => content + .into_iter() + .map(|content| Message { + role: "assistant".to_owned(), + content: content.into(), + }) + .chain(tool_calls.into_iter().map(|tool_call| Message { + role: "assistant".to_owned(), + content: Content::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: tool_call.function.arguments, + }, + })) + .collect::>(), + + message::Message::Tool { id, content } => vec![Message { + role: "assistant".to_owned(), + content: Content::ToolResult { + tool_use_id: id, + content: content.into(), + is_error: false, + }, + }], + }) + } +} + #[derive(Clone)] pub struct CompletionModel { client: Client, @@ -174,8 +326,6 @@ impl completion::CompletionModel for CompletionModel { // specific requirements of each provider. For now, we just manually check while // building the request as a raw JSON document. - let prompt_with_context = completion_request.prompt_with_context(); - // Check if max_tokens is set, required for Anthropic if completion_request.max_tokens.is_none() { return Err(CompletionError::RequestError( @@ -183,17 +333,29 @@ impl completion::CompletionModel for CompletionModel { )); } + let prompt_message: Vec = completion_request + .prompt_with_context() + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.to_string().into()))?; + + let mut messages = completion_request + .chat_history + .into_iter() + .map(|message| { + message + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.to_string().into())) + }) + .collect::>, _>>()? + .into_iter() + .flatten() + .collect::>(); + + messages.extend(prompt_message); + let mut request = json!({ "model": self.model, - "messages": completion_request - .chat_history - .into_iter() - .map(Message::from) - .chain(iter::once(Message { - role: "user".to_owned(), - content: prompt_with_context, - })) - .collect::>(), + "messages": messages, "max_tokens": completion_request.max_tokens, "system": completion_request.preamble.unwrap_or("".to_string()), }); diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 883204d4..16c60d80 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -11,11 +11,7 @@ use std::collections::HashMap; use crate::{ - agent::AgentBuilder, - completion::{self, CompletionError}, - embeddings::{self, EmbeddingError, EmbeddingsBuilder}, - extractor::ExtractorBuilder, - json_utils, Embed, + agent::AgentBuilder, completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, message, Embed }; use schemars::JsonSchema; @@ -480,8 +476,10 @@ pub struct Message { pub message: String, } -impl From for Message { - fn from(message: completion::Message) -> Self { +impl TryFrom for Message { + type Error = + fn from(message: message::Message) -> Self { + Self { role: match message.role.as_str() { "system" => "SYSTEM".to_owned(), diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index fa1e34fb..68513660 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -14,11 +14,13 @@ use crate::{ completion::{self, CompletionError}, extractor::ExtractorBuilder, json_utils, + message::{self, MessageError}, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; +use thiserror::Error; // ================================================================ // Main Cohere Client @@ -124,15 +126,23 @@ pub struct CompletionResponse { pub usage: Usage, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Message { - pub role: String, + pub role: Role, pub content: String, } +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + #[derive(Deserialize, Debug)] pub struct Delta { - pub role: String, + pub role: Role, pub content: String, } @@ -195,6 +205,52 @@ impl CompletionModel { } } +impl TryFrom for Vec { + type Error = MessageError; + + fn try_from(message: message::Message) -> Result { + Ok(match message { + message::Message::User { content } => content + .into_iter() + .map(|content| match content { + message::UserContent::Text { text } => Ok(Message { + role: Role::User, + content: text, + }), + _ => Err(MessageError::ConversionError( + "Only text content is supported by Perplexity".to_owned(), + )), + }) + .collect::, _>>()?, + + message::Message::Assistant { + content, + tool_calls, + } => { + if tool_calls.len() > 0 { + return Err(MessageError::ConversionError( + "Tool calls are not supported by Perplexity".to_owned(), + )); + } + + content + .into_iter() + .map(|content| Message { + role: Role::Assistant, + content: content.into(), + }) + .collect::>() + } + + _ => { + return Err(MessageError::ConversionError( + "Only user and assistant messages are supported by Perplexity".to_owned(), + )) + } + }) + } +} + impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; @@ -202,28 +258,34 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + // Add context documents to chat history + let prompt_with_context = completion_request.prompt_with_context(); + // Add preamble to messages (if available) - let mut messages = if let Some(preamble) = &completion_request.preamble { - vec![completion::Message { - role: "system".into(), - content: preamble.clone(), - }] + let mut messages: Vec = if let Some(preamble) = completion_request.preamble { + let message: message::Message = preamble.into(); + message + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.to_string().into()))? } else { vec![] }; - // Add context documents to chat history - let prompt_with_context = completion_request.prompt_with_context(); - // Add chat history to messages - messages.extend(completion_request.chat_history); + for message in completion_request.chat_history { + let converted: Vec = message + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.to_string().into()))?; + messages.extend(converted); + } // Add user prompt to messages - messages.push(completion::Message { - role: "user".to_string(), - content: prompt_with_context, - }); + let user_messages: Vec = prompt_with_context + .try_into() + .map_err(|e: MessageError| CompletionError::RequestError(e.to_string().into()))?; + messages.extend(user_messages); + // Compose request let request = json!({ "model": self.model, "messages": messages,