Skip to content

Commit

Permalink
partial impl 4
Browse files Browse the repository at this point in the history
  • Loading branch information
0xMochan committed Jan 13, 2025
1 parent 7aa449a commit 63b6314
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 155 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

160 changes: 88 additions & 72 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ use futures::{stream, StreamExt, TryStreamExt};
use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
CompletionResponse, Document, ModelChoice, Prompt, PromptError,
},
message::Message,
tool::{Tool, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};
Expand Down Expand Up @@ -165,94 +166,109 @@ pub struct Agent<M: CompletionModel> {
impl<M: CompletionModel> Completion<M> for Agent<M> {
async fn completion(
&self,
prompt: &str,
prompt: impl Into<Message>,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());

Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(prompt.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", toolname);
None
}
})
.collect::<Vec<_>>()
.await;

Ok(self
let completion_request = self
.model
.completion_request(prompt)
.preamble(self.preamble.clone())
.messages(chat_history)
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone()))
.additional_params_opt(self.additional_params.clone());

let agent = match prompt.into().rag_text() {
Some(text) => {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n(&text, *num_sample)
.await?
.into_iter()
.map(|(_, id, doc)| {
// Pretty print the document if possible for better readability
let text = serde_json::to_string_pretty(&doc)
.unwrap_or_else(|_| doc.to_string());

Document {
id,
text,
additional_props: HashMap::new(),
}
})
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
acc.extend(docs);
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids(&text, *num_sample)
.await?
.into_iter()
.map(|(_, id)| id)
.collect::<Vec<_>>(),
)
})
.try_fold(vec![], |mut acc, docs| async {
for doc in docs {
if let Some(tool) = self.tools.get(&doc) {
acc.push(tool.definition(text.into()).await)
} else {
tracing::warn!("Tool implementation not found in toolset: {}", doc);
}
}
Ok(acc)
})
.await
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let static_tools = stream::iter(self.static_tools.iter())
.filter_map(|toolname| async move {
if let Some(tool) = self.tools.get(toolname) {
Some(tool.definition(text.into()).await)
} else {
tracing::warn!(
"Tool implementation not found in toolset: {}",
toolname
);
None
}
})
.collect::<Vec<_>>()
.await;

completion_request
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
}
None => completion_request,
};
Ok(agent)
}
}

impl<M: CompletionModel> Prompt for Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
async fn prompt(&self, prompt: impl Into<String>) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}

impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
async fn chat(
&self,
prompt: impl Into<String>,
chat_history: Vec<Message>,
) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
CompletionResponse {
choice: ModelChoice::Message(msg),
Expand Down
5 changes: 4 additions & 1 deletion rig-core/src/cli_chatbot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::io::{self, Write};

use crate::completion::{Chat, Message, PromptError};
use crate::{
completion::{Chat, PromptError},
message::Message,
};

/// Utility function to create a simple REPL CLI chatbot from a type that implements the
/// `Chat` trait.
Expand Down
90 changes: 9 additions & 81 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::{json_utils, tool::ToolSetError, OneOrMany};
use crate::{json_utils, message::Message, tool::ToolSetError};

// Errors
#[derive(Debug, Error)]
Expand Down Expand Up @@ -102,78 +102,6 @@ pub enum PromptError {
ToolError(#[from] ToolSetError),
}

// ================================================================
// Request models
// ================================================================
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
System {
content: OneOrMany<String>,
},
User {
content: OneOrMany<String>,
},
Assistant {
refusal: Option<String>,
content: OneOrMany<String>,
tool_calls: OneOrMany<ToolCall>,
},
Tool {
id: String,
content: String,
},
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunction {
pub name: String,
pub arguments: String,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum UserContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
Image {
image_url: String,
detail: ImageDetail,
},
#[serde(rename = "input_audio")]
Audio { data: String, format: String },
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
Low,
High,
Auto,
}

impl std::str::FromStr for ImageDetail {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"low" => Ok(ImageDetail::Low),
"high" => Ok(ImageDetail::High),
"auto" => Ok(ImageDetail::Auto),
_ => Err(()),
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
pub id: String,
Expand Down Expand Up @@ -226,7 +154,7 @@ pub trait Prompt: Send + Sync {
/// If the tool does not exist, or the tool call fails, then an error is returned.
fn prompt(
&self,
prompt: &str,
prompt: impl Into<String>,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}

Expand All @@ -242,7 +170,7 @@ pub trait Chat: Send + Sync {
/// If the tool does not exist, or the tool call fails, then an error is returned.
fn chat(
&self,
prompt: &str,
prompt: impl Into<String>,
chat_history: Vec<Message>,
) -> impl std::future::Future<Output = Result<String, PromptError>> + Send;
}
Expand Down Expand Up @@ -302,15 +230,15 @@ pub trait CompletionModel: Clone + Send + Sync {
+ Send;

/// Generates a completion request builder for the given `prompt`.
fn completion_request(&self, prompt: &str) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt.to_string())
fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
CompletionRequestBuilder::new(self.clone(), prompt)
}
}

/// Struct representing a general completion request that can be sent to a completion model provider.
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: String,
pub prompt: Message,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The chat history to be sent to the completion model provider
Expand Down Expand Up @@ -391,7 +319,7 @@ impl CompletionRequest {
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
model: M,
prompt: String,
prompt: Message,
preamble: Option<String>,
chat_history: Vec<Message>,
documents: Vec<Document>,
Expand All @@ -402,10 +330,10 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
pub fn new(model: M, prompt: String) -> Self {
pub fn new(model: M, prompt: impl Into<Message>) -> Self {
Self {
model,
prompt,
prompt: prompt.into(),
preamble: None,
chat_history: Vec::new(),
documents: Vec::new(),
Expand Down
1 change: 1 addition & 0 deletions rig-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub mod embeddings;
pub mod extractor;
pub(crate) mod json_utils;
pub mod loaders;
pub mod message;
pub mod one_or_many;
pub mod providers;
pub mod tool;
Expand Down
Loading

0 comments on commit 63b6314

Please sign in to comment.