diff --git a/Cargo.lock b/Cargo.lock
index e09baa4c..fc0b179d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6413,6 +6413,7 @@ dependencies = [
  "assert_fs",
  "async-stream",
  "bytes",
+ "dotenvy",
  "futures",
  "glob",
  "lopdf",
diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml
index a72ca0aa..be80e5ae 100644
--- a/rig-core/Cargo.toml
+++ b/rig-core/Cargo.toml
@@ -37,6 +37,7 @@ assert_fs = "1.1.2"
 tokio = { version = "1.34.0", features = ["full"] }
 tracing-subscriber = "0.3.18"
 tokio-test = "0.4.4"
+dotenvy = "0.15.7"
 
 [features]
 all = ["derive", "pdf", "rayon"]
diff --git a/rig-core/examples/agent_with_tools_api.rs b/rig-core/examples/agent_with_tools_api.rs
new file mode 100644
index 00000000..ebfe5b2f
--- /dev/null
+++ b/rig-core/examples/agent_with_tools_api.rs
@@ -0,0 +1,95 @@
+use std::str;
+
+use anyhow::Result;
+use rig::{
+    completion::{Prompt, ToolDefinition},
+    providers,
+    tool::Tool,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+#[derive(Deserialize)]
+struct QueryArgs {
+    query: String,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Search error")]
+struct SearchError;
+
+#[derive(Serialize, Deserialize)]
+struct SearchResults {
+    results: Vec<SearchResult>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct SearchResult {
+    title: String,
+    url: String,
+}
+
+#[derive(Deserialize, Serialize)]
+struct SearchApiTool;
+impl Tool for SearchApiTool {
+    const NAME: &'static str = "search";
+
+    type Error = SearchError;
+    type Args = QueryArgs;
+    type Output = SearchResults;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        ToolDefinition {
+            name: "search".to_string(),
+            description: "Search on the internet.".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "query": {
+                        "type": "string",
+                        "description": "The search query"
+                    },
+                }
+            }),
+        }
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = SearchResults {
+            results: vec![SearchResult {
+                title: format!("Example Website with terms: {}", args.query),
+                url: "https://example.com".to_string(),
+            }],
+        };
+        Ok(result)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Load environment variables
+    dotenvy::dotenv().ok();
+
+    // Create OpenAI client
+    let openai_client = providers::openai::Client::from_env();
+
+    // Create agent with a single context prompt and a single tool
+    let search_agent = openai_client
+        .agent(providers::openai::GPT_4O)
+        .preamble("You are an assistant helping to find information on the internet. Use the tools provided to answer the user's question.")
+        .max_tokens(1024)
+        .tool(SearchApiTool)
+        .build();
+
+    // Prompt the agent and print the response
+    println!(
+        "Search Agent: {}",
+        search_agent
+            .prompt("Search for 'example' and tell me the title and url of each result you find")
+            .await?
+    );
+
+    Ok(())
+}
diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs
index 0c2c5c65..8fed6367 100644
--- a/rig-core/examples/chain.rs
+++ b/rig-core/examples/chain.rs
@@ -60,7 +60,7 @@ async fn main() -> Result<(), anyhow::Error> {
             ),
             Err(err) => {
                 println!("Error: {}! Prompting without additional context", err);
-                format!("{prompt}")
+                String::from(prompt)
             }
         })
         // Chain a "prompt" operation which will prompt out agent with the final prompt
diff --git a/rig-core/examples/debate.rs b/rig-core/examples/debate.rs
index 81e81265..3382c309 100644
--- a/rig-core/examples/debate.rs
+++ b/rig-core/examples/debate.rs
@@ -43,28 +43,17 @@ impl Debater {
 
             let resp_a = self.gpt_4.chat(&prompt_a, history_a.clone()).await?;
             println!("GPT-4:\n{}", resp_a);
-            history_a.push(Message {
-                role: "user".into(),
-                content: prompt_a.clone(),
-            });
-            history_a.push(Message {
-                role: "assistant".into(),
-                content: resp_a.clone(),
-            });
+            history_a.push(Message::user(&prompt_a));
+
+            history_a.push(Message::assistant(&resp_a));
 
             println!("================================================================");
 
             let resp_b = self.coral.chat(&resp_a, history_b.clone()).await?;
             println!("Coral:\n{}", resp_b);
             println!("================================================================");
-            history_b.push(Message {
-                role: "user".into(),
-                content: resp_a.clone(),
-            });
-            history_b.push(Message {
-                role: "assistant".into(),
-                content: resp_b.clone(),
-            });
+            history_b.push(Message::user(&resp_a));
+            history_b.push(Message::assistant(&resp_b));
 
             last_resp_b = Some(resp_b)
         }
diff --git a/rig-core/src/cli_chatbot.rs b/rig-core/src/cli_chatbot.rs
index df33fa8e..0a74aaec 100644
--- a/rig-core/src/cli_chatbot.rs
+++ b/rig-core/src/cli_chatbot.rs
@@ -27,14 +27,8 @@ pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
                 tracing::info!("Prompt:\n{}\n", input);
 
                 let response = chatbot.chat(input, chat_log.clone()).await?;
-                chat_log.push(Message {
-                    role: "user".into(),
-                    content: input.into(),
-                });
-                chat_log.push(Message {
-                    role: "assistant".into(),
-                    content: response.clone(),
-                });
+                chat_log.push(Message::user(input));
+                chat_log.push(Message::assistant(&response));
 
                 println!("========================== Response ============================");
                 println!("{response}");
diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs
index 48ce8f8b..4fc61d9d 100644
--- a/rig-core/src/completion.rs
+++ b/rig-core/src/completion.rs
@@ -111,6 +111,41 @@ pub struct Message {
     /// "system", "user", or "assistant"
     pub role: String,
     pub content: String,
+    #[serde(default)]
+    pub tool_calls: Vec<ToolCall>,
+}
+
+impl Message {
+    pub fn system(content: &str) -> Self {
+        Self {
+            role: "system".to_string(),
+            content: content.to_string(),
+            tool_calls: Vec::new(),
+        }
+    }
+
+    pub fn user(content: &str) -> Self {
+        Self {
+            role: "user".to_string(),
+            content: content.to_string(),
+            tool_calls: Vec::new(),
+        }
+    }
+
+    pub fn assistant(content: &str) -> Self {
+        Self {
+            role: "assistant".to_string(),
+            content: content.to_string(),
+            tool_calls: Vec::new(),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ToolCall {
+    pub call_id: String,
+    pub function_name: String,
+    pub function_params: serde_json::Value,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
@@ -208,7 +243,7 @@ pub trait Completion {
 
 /// General completion response struct that contains the high-level completion choice
 /// and the raw response.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct CompletionResponse<T> {
     /// The completion choice returned by the completion model provider
     pub choice: ModelChoice,
@@ -217,7 +252,7 @@ pub struct CompletionResponse<T> {
 }
 
 /// Enum representing the high-level completion choice returned by the completion model provider.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum ModelChoice {
     /// Represents a completion response as a message
     Message(String),
diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs
index 389e17ba..2f70259b 100644
--- a/rig-core/src/embeddings/builder.rs
+++ b/rig-core/src/embeddings/builder.rs
@@ -366,7 +366,7 @@ mod tests {
             .unwrap();
 
         result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
-            fake_definition_1.cmp(&fake_definition_2)
+            fake_definition_1.cmp(fake_definition_2)
         });
 
        assert_eq!(result.len(), 2);
diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs
index a6606a40..6ff22346 100644
--- a/rig-core/src/providers/gemini/completion.rs
+++ b/rig-core/src/providers/gemini/completion.rs
@@ -55,10 +55,7 @@ impl completion::CompletionModel for CompletionModel {
 
         let prompt_with_context = completion_request.prompt_with_context();
 
-        full_history.push(completion::Message {
-            role: "user".into(),
-            content: prompt_with_context,
-        });
+        full_history.push(completion::Message::user(&prompt_with_context));
 
         // Handle Gemini specific parameters
         let additional_params = completion_request
diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs
index ae6e2495..34d26606 100644
--- a/rig-core/src/providers/hyperbolic.rs
+++ b/rig-core/src/providers/hyperbolic.rs
@@ -254,10 +254,7 @@ impl completion::CompletionModel for CompletionModel {
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         // Add preamble to chat history (if available)
         let mut full_history = if let Some(preamble) = &completion_request.preamble {
-            vec![completion::Message {
-                role: "system".into(),
-                content: preamble.clone(),
-            }]
+            vec![completion::Message::system(preamble)]
         } else {
             vec![]
         };
@@ -269,10 +266,7 @@ impl completion::CompletionModel for CompletionModel {
         let prompt_with_context = completion_request.prompt_with_context();
 
         // Add context documents to chat history
-        full_history.push(completion::Message {
-            role: "user".into(),
-            content: prompt_with_context,
-        });
+        full_history.push(completion::Message::user(&prompt_with_context));
 
         let request = json!({
             "model": self.model,
diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs
index d6ecabe9..2adf2049 100644
--- a/rig-core/src/providers/openai.rs
+++ b/rig-core/src/providers/openai.rs
@@ -219,7 +219,7 @@ pub struct EmbeddingData {
     pub index: usize,
 }
 
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct Usage {
     pub prompt_tokens: usize,
     pub total_tokens: usize,
@@ -356,7 +356,7 @@ pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
 /// `gpt-3.5-turbo-instruct` completion model
 pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct CompletionResponse {
     pub id: String,
     pub object: String,
@@ -417,7 +417,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
     pub tool_calls: Option<Vec<ToolCall>>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct ToolCall {
     pub id: String,
     pub r#type: String,
     pub function: Function,
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ToolDefinition {
     pub r#type: String,
     pub function: completion::ToolDefinition,
@@ -454,7 +454,7 @@ impl From<completion::ToolDefinition> for ToolDefinition {
     }
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct Function {
     pub name: String,
     pub arguments: String,
@@ -486,10 +486,7 @@ impl completion::CompletionModel for CompletionModel {
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         // Add preamble to chat history (if available)
         let mut full_history = if let Some(preamble) = &completion_request.preamble {
-            vec![completion::Message {
-                role: "system".into(),
-                content: preamble.clone(),
-            }]
+            vec![completion::Message::system(preamble)]
         } else {
             vec![]
         };
@@ -501,10 +498,7 @@ impl completion::CompletionModel for CompletionModel {
         let prompt_with_context = completion_request.prompt_with_context();
 
         // Add context documents to chat history
-        full_history.push(completion::Message {
-            role: "user".into(),
-            content: prompt_with_context,
-        });
+        full_history.push(completion::Message::user(&prompt_with_context));
 
         let request = if completion_request.tools.is_empty() {
             json!({
@@ -542,6 +536,7 @@ impl completion::CompletionModel for CompletionModel {
                     "OpenAI completion token usage: {:?}",
                     response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                 );
+
                 response.try_into()
             }
             ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
index e82d39dc..0e44f343 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -205,10 +205,7 @@ impl completion::CompletionModel for CompletionModel {
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         // Add preamble to messages (if available)
         let mut messages = if let Some(preamble) = &completion_request.preamble {
-            vec![completion::Message {
-                role: "system".into(),
-                content: preamble.clone(),
-            }]
+            vec![completion::Message::system(preamble)]
        } else {
             vec![]
         };
@@ -220,10 +217,7 @@ impl completion::CompletionModel for CompletionModel {
         messages.extend(completion_request.chat_history);
 
         // Add user prompt to messages
-        messages.push(completion::Message {
-            role: "user".to_string(),
-            content: prompt_with_context,
-        });
+        messages.push(completion::Message::user(&prompt_with_context));
 
         let request = json!({
             "model": self.model,
diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs
index e9109cf1..f8435626 100644
--- a/rig-core/src/providers/xai/completion.rs
+++ b/rig-core/src/providers/xai/completion.rs
@@ -44,10 +44,7 @@ impl completion::CompletionModel for CompletionModel {
         mut completion_request: completion::CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         let mut messages = if let Some(preamble) = &completion_request.preamble {
-            vec![completion::Message {
-                role: "system".into(),
-                content: preamble.clone(),
-            }]
+            vec![completion::Message::system(preamble)]
         } else {
             vec![]
         };
@@ -55,10 +52,7 @@ impl completion::CompletionModel for CompletionModel {
 
         let prompt_with_context = completion_request.prompt_with_context();
 
-        messages.push(completion::Message {
-            role: "user".into(),
-            content: prompt_with_context,
-        });
+        messages.push(completion::Message::user(&prompt_with_context));
 
         let mut request = if completion_request.tools.is_empty() {
             json!({
diff --git a/rig-eternalai/src/providers/eternalai.rs b/rig-eternalai/src/providers/eternalai.rs
index 4fd15019..435bf160 100644
--- a/rig-eternalai/src/providers/eternalai.rs
+++ b/rig-eternalai/src/providers/eternalai.rs
@@ -463,10 +463,7 @@ impl completion::CompletionModel for CompletionModel {
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         // Add preamble to chat history (if available)
         let mut full_history = if let Some(preamble) = &completion_request.preamble {
-            vec![completion::Message {
-                role: "system".into(),
-                content: preamble.clone(),
-            }]
+            vec![completion::Message::system(preamble)]
         } else {
             vec![]
         };
@@ -503,10 +500,7 @@ impl completion::CompletionModel for CompletionModel {
                     tracing::info!("on-chain sytem prompt is none")
                 }
                 Some(value) => {
-                    let temp = completion::Message {
-                        role: "system".into(),
-                        content: value,
-                    };
+                    let temp = completion::Message::system(&value);
                     full_history.push(temp);
                 }
             }
@@ -519,10 +513,7 @@ impl completion::CompletionModel for CompletionModel {
         let prompt_with_context = completion_request.prompt_with_context();
 
         // Add context documents to chat history
-        full_history.push(completion::Message {
-            role: "user".into(),
-            content: prompt_with_context,
-        });
+        full_history.push(completion::Message::user(&prompt_with_context));
 
         let mut chain_id = self.chain_id.clone();
         if chain_id.is_empty() {