Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: new field for Message struct to support tool chain calls #253

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
tokio-test = "0.4.4"
dotenvy = "0.15.7"

[features]
all = ["derive", "pdf", "rayon"]
Expand Down
95 changes: 95 additions & 0 deletions rig-core/examples/agent_with_tools_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::str;

use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;

/// Arguments the model supplies when invoking the search tool,
/// deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct QueryArgs {
    /// The free-text search query.
    query: String,
}

/// Error type for the search tool. Required by the `Tool` trait;
/// never actually produced in this mock example.
#[derive(Debug, thiserror::Error)]
#[error("Search error")]
struct SearchError;

/// Container for the list of hits returned by the search tool.
#[derive(Serialize, Deserialize)]
struct SearchResults {
    results: Vec<SearchResult>,
}

/// A single search hit: a page title plus its URL.
#[derive(Serialize, Deserialize)]
struct SearchResult {
    title: String,
    url: String,
}

/// A mock search tool demonstrating how to wire a `Tool` into an agent.
///
/// It does not hit the network; `call` fabricates a single result that
/// echoes the query back so the example is fully deterministic.
#[derive(Deserialize, Serialize)]
struct SearchApiTool;

impl Tool for SearchApiTool {
    const NAME: &'static str = "search";

    type Error = SearchError;
    type Args = QueryArgs;
    type Output = SearchResults;

    /// Describes the tool to the model: its name, a human-readable
    /// description, and a JSON schema for the expected arguments.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            // Reuse the trait constant instead of repeating the literal,
            // so the advertised name cannot drift from `NAME`.
            name: Self::NAME.to_string(),
            // Fixed: the original string literal contained an embedded
            // newline and indentation ("Search on the internet\n    ."),
            // which leaked raw whitespace into the tool description.
            description: "Search on the internet.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    }
                }
            }),
        }
    }

    /// Executes the (mock) search, returning one fabricated hit whose
    /// title embeds the query.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        Ok(SearchResults {
            results: vec![SearchResult {
                title: format!("Example Website with terms: {}", args.query),
                url: "https://example.com".to_string(),
            }],
        })
    }
}

/// Example entry point: builds an OpenAI-backed agent equipped with the
/// mock search tool, prompts it once, and prints the response.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Load environment variables (e.g. the OpenAI API key) from a .env file,
    // if one exists; silently continue otherwise.
    dotenvy::dotenv().ok();

    // Create the OpenAI client from environment configuration.
    let openai_client = providers::openai::Client::from_env();

    // Create an agent with a single context prompt and the search tool.
    // Renamed from `calculator_agent`: this example wires a search tool,
    // not a calculator — the old name was a copy-paste leftover.
    let search_agent = openai_client
        .agent(providers::openai::GPT_4O)
        .preamble("You are an assistant helping to find information on the internet. Use the tools provided to answer the user's question.")
        .max_tokens(1024)
        .tool(SearchApiTool)
        .build();

    // Prompt the agent and print the response.
    println!(
        "Search Agent: {}",
        search_agent
            .prompt("Search for 'example' and tell me the title and url of each result you find")
            .await?
    );

    Ok(())
}
2 changes: 1 addition & 1 deletion rig-core/examples/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn main() -> Result<(), anyhow::Error> {
),
Err(err) => {
println!("Error: {}! Prompting without additional context", err);
format!("{prompt}")
String::from(prompt)
}
})
// Chain a "prompt" operation which will prompt out agent with the final prompt
Expand Down
21 changes: 5 additions & 16 deletions rig-core/examples/debate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,17 @@ impl Debater {

let resp_a = self.gpt_4.chat(&prompt_a, history_a.clone()).await?;
println!("GPT-4:\n{}", resp_a);
history_a.push(Message {
role: "user".into(),
content: prompt_a.clone(),
});
history_a.push(Message {
role: "assistant".into(),
content: resp_a.clone(),
});
history_a.push(Message::user(&prompt_a));

history_a.push(Message::assistant(&resp_a));
println!("================================================================");

let resp_b = self.coral.chat(&resp_a, history_b.clone()).await?;
println!("Coral:\n{}", resp_b);
println!("================================================================");

history_b.push(Message {
role: "user".into(),
content: resp_a.clone(),
});
history_b.push(Message {
role: "assistant".into(),
content: resp_b.clone(),
});
history_b.push(Message::user(&resp_a));
history_b.push(Message::assistant(&resp_b));

last_resp_b = Some(resp_b)
}
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/cli_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@ pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
tracing::info!("Prompt:\n{}\n", input);

let response = chatbot.chat(input, chat_log.clone()).await?;
chat_log.push(Message {
role: "user".into(),
content: input.into(),
});
chat_log.push(Message {
role: "assistant".into(),
content: response.clone(),
});
chat_log.push(Message::user(input));
chat_log.push(Message::assistant(&response));

println!("========================== Response ============================");
println!("{response}");
Expand Down
39 changes: 37 additions & 2 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ pub struct Message {
/// "system", "user", or "assistant"
pub role: String,
pub content: String,
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
}

impl Message {
    /// Internal constructor shared by the role-specific helpers: builds a
    /// message with the given role and content and no tool calls.
    fn with_role(role: &str, content: &str) -> Self {
        Self {
            role: role.to_string(),
            content: content.to_string(),
            tool_calls: Vec::new(),
        }
    }

    /// Creates a `"system"` message with the given content.
    pub fn system(content: &str) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a `"user"` message with the given content.
    pub fn user(content: &str) -> Self {
        Self::with_role("user", content)
    }

    /// Creates an `"assistant"` message with the given content.
    pub fn assistant(content: &str) -> Self {
        Self::with_role("assistant", content)
    }
}

/// A tool invocation attached to a chat [`Message`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier for this call — presumably the provider-assigned id used
    /// to correlate the tool result back to the request; confirm against
    /// the provider implementations.
    pub call_id: String,
    /// Name of the function/tool the model requested.
    pub function_name: String,
    /// Arguments for the call as raw JSON.
    pub function_params: serde_json::Value,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -208,7 +243,7 @@ pub trait Completion<M: CompletionModel> {

/// General completion response struct that contains the high-level completion choice
/// and the raw response.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CompletionResponse<T> {
/// The completion choice returned by the completion model provider
pub choice: ModelChoice,
Expand All @@ -217,7 +252,7 @@ pub struct CompletionResponse<T> {
}

/// Enum representing the high-level completion choice returned by the completion model provider.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum ModelChoice {
/// Represents a completion response as a message
Message(String),
Expand Down
2 changes: 1 addition & 1 deletion rig-core/src/embeddings/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ mod tests {
.unwrap();

result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
fake_definition_1.cmp(&fake_definition_2)
fake_definition_1.cmp(fake_definition_2)
});

assert_eq!(result.len(), 2);
Expand Down
5 changes: 1 addition & 4 deletions rig-core/src/providers/gemini/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ impl completion::CompletionModel for CompletionModel {

let prompt_with_context = completion_request.prompt_with_context();

full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
full_history.push(completion::Message::user(&prompt_with_context));

// Handle Gemini specific parameters
let additional_params = completion_request
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/providers/hyperbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,7 @@ impl completion::CompletionModel for CompletionModel {
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
vec![completion::Message::system(preamble)]
} else {
vec![]
};
Expand All @@ -269,10 +266,7 @@ impl completion::CompletionModel for CompletionModel {
let prompt_with_context = completion_request.prompt_with_context();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
full_history.push(completion::Message::user(&prompt_with_context));

let request = json!({
"model": self.model,
Expand Down
25 changes: 10 additions & 15 deletions rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub struct EmbeddingData {
pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub total_tokens: usize,
Expand Down Expand Up @@ -356,7 +356,7 @@ pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
Expand Down Expand Up @@ -417,29 +417,29 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
}
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub logprobs: Option<serde_json::Value>,
pub finish_reason: String,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
pub role: String,
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
pub id: String,
pub r#type: String,
pub function: Function,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolDefinition {
pub r#type: String,
pub function: completion::ToolDefinition,
Expand All @@ -454,7 +454,7 @@ impl From<completion::ToolDefinition> for ToolDefinition {
}
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Function {
pub name: String,
pub arguments: String,
Expand Down Expand Up @@ -486,10 +486,7 @@ impl completion::CompletionModel for CompletionModel {
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
vec![completion::Message::system(preamble)]
} else {
vec![]
};
Expand All @@ -501,10 +498,7 @@ impl completion::CompletionModel for CompletionModel {
let prompt_with_context = completion_request.prompt_with_context();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
full_history.push(completion::Message::user(&prompt_with_context));

let request = if completion_request.tools.is_empty() {
json!({
Expand Down Expand Up @@ -542,6 +536,7 @@ impl completion::CompletionModel for CompletionModel {
"OpenAI completion token usage: {:?}",
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
);

response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/providers/perplexity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,7 @@ impl completion::CompletionModel for CompletionModel {
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to messages (if available)
let mut messages = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
vec![completion::Message::system(preamble)]
} else {
vec![]
};
Expand All @@ -220,10 +217,7 @@ impl completion::CompletionModel for CompletionModel {
messages.extend(completion_request.chat_history);

// Add user prompt to messages
messages.push(completion::Message {
role: "user".to_string(),
content: prompt_with_context,
});
messages.push(completion::Message::user(&prompt_with_context));

let request = json!({
"model": self.model,
Expand Down
Loading