Skip to content

Commit

Permalink
featadd new field for Message struct to support tool chain calls, it …
Browse files Browse the repository at this point in the history
…breaks current usage of Message api as you need a new field
  • Loading branch information
carlos-verdes committed Jan 29, 2025
1 parent 3353c21 commit 501f891
Show file tree
Hide file tree
Showing 14 changed files with 201 additions and 87 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
tokio-test = "0.4.4"
dotenvy = "0.15.7"

[features]
all = ["derive", "pdf", "rayon"]
Expand Down
95 changes: 95 additions & 0 deletions rig-core/examples/agent_with_tools_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::str;

use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Deserialize)]
struct QueryArgs {
query: String,
}

#[derive(Debug, thiserror::Error)]
#[error("Search error")]
struct SearchError;

#[derive(Serialize, Deserialize)]
struct SearchResults {
results: Vec<SearchResult>,
}

#[derive(Serialize, Deserialize)]
struct SearchResult {
title: String,
url: String,
}

#[derive(Deserialize, Serialize)]
struct SearchApiTool;
impl Tool for SearchApiTool {
const NAME: &'static str = "search";

type Error = SearchError;
type Args = QueryArgs;
type Output = SearchResults;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "search".to_string(),
description: "Search on the internet
."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
}
}),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = SearchResults {
results: vec![SearchResult {
title: format!("Example Website with terms: {}", args.query),
url: "https://example.com".to_string(),
}],
};
Ok(result)
}
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Load environment variables
dotenvy::dotenv().ok();

// Create OpenAI client
let openai_client = providers::openai::Client::from_env();

// Create agent with a single context prompt and two tools
let calculator_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble("You are an assistant helping to find information on the internet. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(SearchApiTool)
.build();

// Prompt the agent and print the response
println!(
"Search Agent: {}",
calculator_agent
.prompt("Search for 'example' and tell me the title and url of each result you find")
.await?
);

Ok(())
}
2 changes: 1 addition & 1 deletion rig-core/examples/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn main() -> Result<(), anyhow::Error> {
),
Err(err) => {
println!("Error: {}! Prompting without additional context", err);
format!("{prompt}")
String::from(prompt)
}
})
// Chain a "prompt" operation which will prompt out agent with the final prompt
Expand Down
21 changes: 5 additions & 16 deletions rig-core/examples/debate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,17 @@ impl Debater {

let resp_a = self.gpt_4.chat(&prompt_a, history_a.clone()).await?;
println!("GPT-4:\n{}", resp_a);
history_a.push(Message {
role: "user".into(),
content: prompt_a.clone(),
});
history_a.push(Message {
role: "assistant".into(),
content: resp_a.clone(),
});
history_a.push(Message::user(&prompt_a));

history_a.push(Message::assistant(&resp_a));
println!("================================================================");

let resp_b = self.coral.chat(&resp_a, history_b.clone()).await?;
println!("Coral:\n{}", resp_b);
println!("================================================================");

history_b.push(Message {
role: "user".into(),
content: resp_a.clone(),
});
history_b.push(Message {
role: "assistant".into(),
content: resp_b.clone(),
});
history_b.push(Message::user(&resp_a));
history_b.push(Message::assistant(&resp_b));

last_resp_b = Some(resp_b)
}
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/cli_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@ pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
tracing::info!("Prompt:\n{}\n", input);

let response = chatbot.chat(input, chat_log.clone()).await?;
chat_log.push(Message {
role: "user".into(),
content: input.into(),
});
chat_log.push(Message {
role: "assistant".into(),
content: response.clone(),
});
chat_log.push(Message::user(input));
chat_log.push(Message::assistant(&response));

println!("========================== Response ============================");
println!("{response}");
Expand Down
39 changes: 37 additions & 2 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ pub struct Message {
/// "system", "user", or "assistant"
pub role: String,
pub content: String,
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
}

impl Message {
pub fn system(content: &str) -> Self {
Self {
role: "system".to_string(),
content: content.to_string(),
tool_calls: Vec::new(),
}
}

pub fn user(content: &str) -> Self {
Self {
role: "user".to_string(),
content: content.to_string(),
tool_calls: Vec::new(),
}
}

pub fn assistant(content: &str) -> Self {
Self {
role: "assistant".to_string(),
content: content.to_string(),
tool_calls: Vec::new(),
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
pub call_id: String,
pub function_name: String,
pub function_params: serde_json::Value,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -208,7 +243,7 @@ pub trait Completion<M: CompletionModel> {

/// General completion response struct that contains the high-level completion choice
/// and the raw response.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CompletionResponse<T> {
/// The completion choice returned by the completion model provider
pub choice: ModelChoice,
Expand All @@ -217,7 +252,7 @@ pub struct CompletionResponse<T> {
}

/// Enum representing the high-level completion choice returned by the completion model provider.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum ModelChoice {
/// Represents a completion response as a message
Message(String),
Expand Down
2 changes: 1 addition & 1 deletion rig-core/src/embeddings/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ mod tests {
.unwrap();

result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
fake_definition_1.cmp(&fake_definition_2)
fake_definition_1.cmp(fake_definition_2)
});

assert_eq!(result.len(), 2);
Expand Down
5 changes: 1 addition & 4 deletions rig-core/src/providers/gemini/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ impl completion::CompletionModel for CompletionModel {

let prompt_with_context = completion_request.prompt_with_context();

full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
full_history.push(completion::Message::user(&prompt_with_context));

// Handle Gemini specific parameters
let additional_params = completion_request
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/providers/hyperbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,7 @@ impl completion::CompletionModel for CompletionModel {
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
vec![completion::Message::system(preamble)]
} else {
vec![]
};
Expand All @@ -269,10 +266,7 @@ impl completion::CompletionModel for CompletionModel {
let prompt_with_context = completion_request.prompt_with_context();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
full_history.push(completion::Message::user(&prompt_with_context));

let request = json!({
"model": self.model,
Expand Down
Loading

0 comments on commit 501f891

Please sign in to comment.