Skip to content

Commit

Permalink
Merge pull request #21 from 0xPlaygrounds/refactor/agent-type
Browse files Browse the repository at this point in the history
Refactor: Merge and generalize Agent types
  • Loading branch information
cvauclair authored Sep 17, 2024
2 parents f08dd2e + aa75c30 commit 3c67742
Show file tree
Hide file tree
Showing 12 changed files with 322 additions and 464 deletions.
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
.tool_rag_agent("gpt-4")
.agent("gpt-4")
.preamble(
"You are an assistant here to help the user select which tool is most appropriate to perform arithmetic operations.
Follow these instructions closely.
Expand Down
5 changes: 2 additions & 3 deletions rig-core/examples/multi_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use rig::{
agent::{Agent, AgentBuilder},
cli_chatbot::cli_chatbot,
completion::{Chat, CompletionModel, Message, PromptError},
model::{Model, ModelBuilder},
providers::openai::Client as OpenAIClient,
};

Expand All @@ -14,7 +13,7 @@ use rig::{
/// prompt in english, before answering it with GPT-4. The answer in english is returned.
struct EnglishTranslator<M: CompletionModel> {
translator_agent: Agent<M>,
gpt4: Model<M>,
gpt4: Agent<M>,
}

impl<M: CompletionModel> EnglishTranslator<M> {
Expand All @@ -29,7 +28,7 @@ impl<M: CompletionModel> EnglishTranslator<M> {
.build(),

// Create the GPT4 model
gpt4: ModelBuilder::new(model).build()
gpt4: AgentBuilder::new(model).build()
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Create vector store index
let index = vector_store.index(embedding_model);

let rag_agent = openai_client.context_rag_agent("gpt-4")
let rag_agent = openai_client.agent("gpt-4")
.preamble("
You are a dictionary assistant here to assist the user in understanding the meaning of words.
You will find additional non-standard word definitions that could be useful below.
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
.tool_rag_agent("gpt-4")
.agent("gpt-4")
.preamble("You are a calculator here to help the user perform arithmetic operations.")
// Add a dynamic tool source with a sample rate of 1 (i.e.: only
// 1 additional tool will be added to prompts)
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/simple_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ async fn main() {
// Create OpenAI client and model
let openai_client = openai::Client::from_env();

let gpt4 = openai_client.model("gpt-4").build();
let gpt4 = openai_client.agent("gpt-4").build();

// Prompt the model and print its response
let response = gpt4
Expand Down
217 changes: 168 additions & 49 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
//! This module contains the implementation of the [Agent] struct and its builder.
//!
//! The [Agent] struct represents an LLM agent, which combines an LLM model with a preamble (system prompt),
//! a set of context documents, and a set of static tools. The agent can be used to interact with the LLM model
//! by providing prompts and chat history without having to provide the preamble and other parameters everytime.
//! a set of context documents, and a set of tools. Note: both context documents and tools can be either
//! static (i.e.: they are always provided) or dynamic (i.e.: they are RAGged at prompt-time).
//!
//! The [Agent] struct is highly configurable, allowing the user to define anything from
//! a simple bot with a specific system prompt to a complex RAG system with a set of dynamic
//! context documents and tools.
//!
//! The [Agent] struct implements the [Completion] and [Prompt] traits, allowing it to be used for generating
//! completions responses and prompts. The [Agent] struct also implements the [Chat] trait, which allows it to
//! be used for generating chat completions.
//!
//! The [AgentBuilder] implements the builder pattern for creating instances of [Agent].
//! It allows configuring the model, preamble, context documents, static tools, temperature, and additional parameters
Expand Down Expand Up @@ -52,16 +60,63 @@
//! .await
//! .expect("Failed to send completion request");
//! ```
//!
//! RAG Agent example
//! ```rust
//! use rig::{
//! completion::Prompt,
//! embeddings::EmbeddingsBuilder,
//! providers::openai,
//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
//! };
//!
//! // Initialize OpenAI client
//! let openai = openai::Client::from_env();
//!
//! // Initialize OpenAI embedding model
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
//!
//! // Create vector store, compute embeddings and load them in the store
//! let mut vector_store = InMemoryVectorStore::default();
//!
//! let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
//! .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
//! .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
//! .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
//! .build()
//! .await
//! .expect("Failed to build embeddings");
//!
//! vector_store.add_documents(embeddings)
//! .await
//! .expect("Failed to add documents");
//!
//! // Create vector store index
//! let index = vector_store.index(embedding_model);
//!
//! let agent = openai.agent(openai::GPT_4O)
//! .preamble("
//! You are a dictionary assistant here to assist the user in understanding the meaning of words.
//! You will find additional non-standard word definitions that could be useful below.
//! ")
//! .dynamic_context(1, index)
//! .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;

use futures::{stream, StreamExt};
use futures::{stream, StreamExt, TryStreamExt};

use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
},
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

/// Struct reprensenting an LLM agent. An agent is an LLM model combined with a preamble
Expand All @@ -85,52 +140,24 @@ use crate::{
/// .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's `gpt-3.5-turbo-1106`, Cohere's `command-r`)
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: String,
/// Context documents always available to the agent
context: Vec<Document>,
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
tools: ToolSet,
}

impl<M: CompletionModel> Agent<M> {
/// Create a new Agent
pub fn new(
model: M,
preamble: String,
static_context: Vec<String>,
static_tools: Vec<impl Tool + 'static>,
temperature: Option<f64>,
additional_params: Option<serde_json::Value>,
) -> Self {
let static_tools_ids = static_tools.iter().map(|tool| tool.name()).collect();

Self {
model,
preamble,
context: static_context
.into_iter()
.enumerate()
.map(|(i, doc)| Document {
id: format!("static_doc_{}", i),
text: doc,
additional_props: HashMap::new(),
})
.collect(),
tools: ToolSet::from_tools(static_tools),
static_tools: static_tools_ids,
temperature,
additional_params,
}
}
pub tools: ToolSet,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
Expand All @@ -139,12 +166,64 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
prompt: &str,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let tool_definitions = stream::iter(self.static_tools.iter())
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| {
// Pretty print the document if possible for better readability
let doc_text = serde_json::to_string_pretty(&doc.document)
.unwrap_or_else(|_| doc.document.to_string());

Document {
id: doc.id,
text: doc_text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| doc)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(prompt.into()).await)
} else {
tracing::error!(target: "rig", "Agent static tool {} not found", toolname);
tracing::warn!("Tool implementation not found in toolset: {}", toolname);
None
}
})
Expand All @@ -156,8 +235,8 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.documents(self.context.clone())
.tools(tool_definitions.clone())
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.additional_params_opt(self.additional_params.clone()))
}
Expand Down Expand Up @@ -206,12 +285,23 @@ impl<M: CompletionModel> Chat for Agent<M> {
/// .build();
/// ```
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// System prompt
preamble: Option<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
static_tools: Vec<String>,
temperature: Option<f64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
}

Expand All @@ -224,13 +314,15 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_tools: vec![],
temperature: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
}
}

/// Set the preamble of the agent
pub fn preamble(mut self, doc: &str) -> Self {
self.preamble = Some(doc.into());
/// Set the system prompt
pub fn preamble(mut self, preamble: &str) -> Self {
self.preamble = Some(preamble.into());
self
}

Expand Down Expand Up @@ -262,6 +354,31 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Add some dynamic context to the agent. On each prompt, `sample` documents from the
/// dynamic context will be inserted in the request.
pub fn dynamic_context(
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context)));
self
}

/// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
/// dynamic toolset will be inserted in the request.
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}

/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
Expand All @@ -278,12 +395,14 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
preamble: self.preamble.unwrap_or_else(|| "".into()),
context: self.static_context,
tools: self.tools,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
}
}
}
Loading

0 comments on commit 3c67742

Please sign in to comment.