refactor(message): update message api types #199

Draft · wants to merge 5 commits into base: main
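At a glance, this PR moves `Message` out of `completion` into a dedicated `message` module and widens the prompt-taking traits (`Prompt`, `Chat`, `Completion`, `CompletionModel::completion_request`) from `&str` to `impl Into<Message> + Send`. A minimal sketch of what a call site might look like after the change; the `rig` crate path and the pre-built `agent` are assumptions, not taken from this diff:

use rig::{completion::{Prompt, PromptError}, message::Message};

// A `&str` keeps working because the diff's tests rely on
// `"...".into()` producing a `Message` (see the test near the end).
async fn ask(agent: &impl Prompt) -> Result<String, PromptError> {
    let quick = agent.prompt("What is the capital of France?").await?;

    // An explicit `Message` can be passed as well.
    let msg: Message = "Answer in one word.".into();
    let _structured = agent.prompt(msg).await?;

    Ok(quick)
}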
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

161 changes: 90 additions & 71 deletions rig-core/src/agent.rs
@@ -113,8 +113,9 @@ use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
CompletionResponse, Document, ModelChoice, Prompt, PromptError,
},
message::Message,
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
@@ -165,94 +166,112 @@ pub struct Agent<M: CompletionModel> {
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: &str,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());
let prompt = prompt.into();
let rag_text = prompt.rag_text().clone();

Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", toolname);
None
}
})
.collect::<Vec<_>>()
.await;

Ok(self
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone()))
.additional_params_opt(self.additional_params.clone());

let agent = match &rag_text {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(&text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());

Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(&text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;

completion_request
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => completion_request,
};
Ok(agent)
}
}

impl<M: CompletionModel> Prompt for Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}

impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
CompletionResponse {
choice: ModelChoice::Message(msg),
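One thing to note about the new `completion` body above: it keys the dynamic-context and dynamic-tool lookups off `prompt.rag_text()`, which this diff calls but does not define. Purely as a hypothetical sketch of the shape that branch relies on (the real implementation lives in the new `message` module and may differ):

use rig::message::{Message, UserContent};

// Hypothetical: extract the plain user text that the vector-store
// `top_n` / `top_n_ids` calls above would search with.
fn rag_text_of(message: &Message) -> Option<String> {
    match message {
        Message::User { content } => content.iter().find_map(|part| match part {
            UserContent::Text { text } => Some(text.clone()),
            _ => None,
        }),
        _ => None,
    }
}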
5 changes: 4 additions & 1 deletion rig-core/src/cli_chatbot.rs
@@ -1,6 +1,9 @@
use std::io::{self, Write};

use crate::completion::{Chat, Message, PromptError};
use crate::{
completion::{Chat, PromptError},
message::Message,
};

/// Utility function to create a simple REPL CLI chatbot from a type that implements the
/// `Chat` trait.
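For context, only the import changes in this file; the REPL helper itself is unchanged. Assuming the helper keeps its existing entry point (not shown in this hunk), wiring it to anything implementing `Chat` would look roughly like:

use rig::{cli_chatbot::cli_chatbot, completion::{Chat, PromptError}};

// Runs the stdin/stdout REPL loop against any `Chat` implementor,
// e.g. an agent built from a provider client.
async fn repl(chatbot: impl Chat) -> Result<(), PromptError> {
    cli_chatbot(chatbot).await
}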
91 changes: 52 additions & 39 deletions rig-core/src/completion.rs
@@ -67,7 +67,11 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::{json_utils, tool::ToolSetError};
use crate::{
json_utils,
message::{Message, UserContent},
tool::ToolSetError,
};

// Errors
#[derive(Debug, Error)]
@@ -102,16 +106,6 @@ pub enum PromptError {
ToolError(#[from] ToolSetError),
}

// ================================================================
// Request models
// ================================================================
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
/// "system", "user", or "assistant"
pub role: String,
pub content: String,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
pub id: String,
@@ -164,7 +158,7 @@ pub trait Prompt: Send + Sync {
/// If the tool does not exist, or the tool call fails, then an error is returned.
fn prompt(
&self,
prompt: &str,
prompt: impl Into<Message> + Send,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

@@ -180,7 +174,7 @@ pub trait Chat: Send + Sync {
/// If the tool does not exist, or the tool call fails, then an error is returned.
fn chat(
&self,
prompt: &str,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}
@@ -200,7 +194,7 @@ pub trait Completion<M: CompletionModel> {
/// contain the `preamble` provided when creating the agent.
fn completion(
&self,
prompt: &str,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>> + Send;
}
@@ -240,15 +234,15 @@ pub trait CompletionModel: Clone + Send + Sync {
+ Send;

/// Generates a completion request builder for the given `prompt`.
fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt.to_string())
fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt)
}
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: String,
pub prompt: Message,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The chat history to be sent to the completion model provider
@@ -266,20 +260,26 @@ pub struct CompletionRequest {
}

impl CompletionRequest {
pub(crate) fn prompt_with_context(&self) -> String {
if !self.documents.is_empty() {
format!(
"<attachments>\n{}</attachments>\n\n{}",
self.documents
pub(crate) fn prompt_with_context(&self) -> Message {
let mut new_prompt = self.prompt.clone();
if let Message::User { ref mut content } = new_prompt {
if !self.documents.is_empty() {
let attachments = self
.documents
.iter()
.map(|doc| doc.to_string())
.collect::<Vec<_>>()
.join(""),
self.prompt
)
} else {
self.prompt.clone()
.join("");
let formatted_content = format!("<attachments>\n{}</attachments>", attachments);
content.insert(
0,
UserContent::Text {
text: formatted_content,
},
);
}
}
new_prompt
}
}

@@ -329,7 +329,7 @@ impl CompletionRequest {
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
model: M,
prompt: String,
prompt: Message,
preamble: Option<String>,
chat_history: Vec<Message>,
documents: Vec<Document>,
@@ -340,10 +340,10 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
pub fn new(model: M, prompt: String) -> Self {
pub fn new(model: M, prompt: impl Into<Message>) -> Self {
Self {
model,
prompt,
prompt: prompt.into(),
preamble: None,
chat_history: Vec::new(),
documents: Vec::new(),
@@ -475,6 +475,8 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {

#[cfg(test)]
mod tests {
use crate::OneOrMany;

use super::*;

#[test]
@@ -525,7 +527,7 @@ mod tests {
};

let request = CompletionRequest {
prompt: "What is the capital of France?".to_string(),
prompt: "What is the capital of France?".into(),
preamble: None,
chat_history: Vec::new(),
documents: vec![doc1, doc2],
@@ -535,14 +537,25 @@ mod tests {
additional_params: None,
};

let expected = concat!(
"<attachments>\n",
"<file id: doc1>\nDocument 1 text.\n</file>\n",
"<file id: doc2>\nDocument 2 text.\n</file>\n",
"</attachments>\n\n",
"What is the capital of France?"
)
.to_string();
let expected = Message::User {
content: OneOrMany::many(vec![
UserContent::Text {
text: concat!(
"<attachments>\n",
"<file id: doc1>\nDocument 1 text.\n</file>\n",
"<file id: doc2>\nDocument 2 text.\n</file>\n",
"</attachments>\n"
)
.to_string(),
},
UserContent::Text {
text: "What is the capital of France?".to_string(),
},
])
.expect("This has more than 1 item"),
};

request.prompt_with_context();

assert_eq!(request.prompt_with_context(), expected);
}
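Beyond the trait methods, the builder path accepts the new prompt type directly. A minimal sketch of driving `CompletionRequestBuilder` by hand, using only methods that appear in the hunks above; the `rig` paths and the pre-built `model` are assumptions:

use rig::completion::{CompletionError, CompletionModel, ModelChoice};
use rig::message::Message;

async fn one_shot<M: CompletionModel>(
    model: &M,
    history: Vec<Message>,
) -> Result<(), CompletionError> {
    let response = model
        .completion_request("What is the capital of France?") // &str -> Message via Into
        .messages(history)
        .temperature_opt(Some(0.0))
        .send()
        .await?;

    // Print the model's text reply; tool calls are ignored in this sketch.
    if let ModelChoice::Message(text) = response.choice {
        println!("{text}");
    }

    Ok(())
}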