From 76dacd7450aa67cd8c496e58d09e72c83ad35ecd Mon Sep 17 00:00:00 2001
From: Garance
Date: Tue, 10 Sep 2024 05:44:54 -0400
Subject: [PATCH 1/3] feat/providers/add perplexity

---
 rig-core/examples/perplexity_agent.rs |  34 ++++
 rig-core/src/providers/mod.rs         |   2 +
 rig-core/src/providers/perplexity.rs  | 248 ++++++++++++++++++++++++++
 3 files changed, 284 insertions(+)
 create mode 100644 rig-core/examples/perplexity_agent.rs
 create mode 100644 rig-core/src/providers/perplexity.rs

diff --git a/rig-core/examples/perplexity_agent.rs b/rig-core/examples/perplexity_agent.rs
new file mode 100644
index 00000000..6f901f24
--- /dev/null
+++ b/rig-core/examples/perplexity_agent.rs
@@ -0,0 +1,34 @@
+use std::env;
+
+use rig::{
+    completion::Prompt,
+    providers::{self, perplexity::LLAMA_3_1_70B_INSTRUCT},
+};
+use serde_json::json;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Create Perplexity client
+    let client = providers::perplexity::Client::new(
+        &env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set"),
+    );
+
+    // Create agent with a single context prompt
+    let agent = client
+        .agent(LLAMA_3_1_70B_INSTRUCT)
+        .preamble("Be precise and concise.")
+        .temperature(0.5)
+        .additional_params(json!({
+            "return_related_questions": true,
+            "return_images": true
+        }))
+        .build();
+
+    // Prompt the agent and print the response
+    let response = agent
+        .prompt("When and where is the next solar eclipse?")
+        .await?;
+    println!("{}", response);
+
+    Ok(())
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index a50f2794..521ee1e4 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -3,6 +3,7 @@
 //! Currently, the following providers are supported:
 //! - Cohere
 //! - OpenAI
+//! - Perplexity
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -39,3 +40,4 @@
 //! be used with the Cohere provider client.
 pub mod cohere;
 pub mod openai;
+pub mod perplexity;
diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
new file mode 100644
index 00000000..1946b01d
--- /dev/null
+++ b/rig-core/src/providers/perplexity.rs
@@ -0,0 +1,248 @@
+//! Perplexity API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::perplexity;
+//!
+//! let client = perplexity::Client::new("YOUR_API_KEY");
+//!
+//! let llama_3_1_sonar_small_online = client.completion_model(perplexity::LLAMA_3_1_SONAR_SMALL_ONLINE);
+//! ```
+
+use crate::{
+    agent::AgentBuilder,
+    completion::{self, CompletionError},
+    extractor::ExtractorBuilder,
+    json_utils,
+    model::ModelBuilder,
+    rag::RagAgentBuilder,
+    vector_store::{NoIndex, VectorStoreIndex},
+};
+
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+// ================================================================
+// Main Perplexity Client
+// ================================================================
+const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai/chat/completions";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
+    }
+
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Perplexity reqwest client should build"),
+        }
+    }
+
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+        self.http_client.post(url)
+    }
+
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    pub fn model(&self, model: &str) -> ModelBuilder<CompletionModel> {
+        ModelBuilder::new(self.completion_model(model))
+    }
+
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+
+    pub fn rag_agent<C: VectorStoreIndex, T: VectorStoreIndex>(
+        &self,
+        model: &str,
+    ) -> RagAgentBuilder<CompletionModel, C, T> {
+        RagAgentBuilder::new(self.completion_model(model))
+    }
+
+    pub fn context_rag_agent<C: VectorStoreIndex>(
+        &self,
+        model: &str,
+    ) -> RagAgentBuilder<CompletionModel, C, NoIndex> {
+        RagAgentBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiErrorResponse {
+    message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+// ================================================================
+// Perplexity Completion API
+// ================================================================
+/// `llama-3.1-sonar-small-128k-online` completion model
+pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
+/// `llama-3.1-sonar-large-128k-online` completion model
+pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
+/// `llama-3.1-sonar-huge-128k-online` completion model
+pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
+/// `llama-3.1-sonar-small-128k-chat` completion model
+pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
+/// `llama-3.1-sonar-large-128k-chat` completion model
+pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
+/// `llama-3.1-8b-instruct` completion model
+pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
+/// `llama-3.1-70b-instruct` completion model
+pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";
+
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub model: String,
+    pub object: String,
+    pub created: u64,
+    #[serde(default)]
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Message {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Delta {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Choice {
+    pub index: usize,
+    pub finish_reason: String,
+    pub message: Message,
+    pub delta: Delta,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
+        match value.choices.as_slice() {
+            [Choice {
+                message: Message { content, .. },
+                ..
+            }, ..] => Ok(completion::CompletionResponse {
+                choice: completion::ModelChoice::Message(content.to_string()),
+                raw_response: value,
+            }),
+            _ => Err(CompletionError::ResponseError(
+                "Response did not contain a message or tool call".into(),
+            )),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let mut messages = completion_request.chat_history.clone();
+        if let Some(preamble) = completion_request.preamble {
+            messages.push(completion::Message {
+                role: "system".to_string(),
+                content: preamble,
+            });
+        }
+        messages.push(completion::Message {
+            role: "user".to_string(),
+            content: completion_request.prompt,
+        });
+
+        let request = json!({
+            "model": self.model,
+            "messages": messages,
+            "temperature": completion_request.temperature,
+        });
+
+        let response = self
+            .client
+            .post("/chat/completions")
+            .json(
+                &if let Some(ref params) = completion_request.additional_params {
+                    json_utils::merge(request.clone(), params.clone())
+                } else {
+                    request.clone()
+                },
+            )
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<ApiResponse<CompletionResponse>>()
+            .await?;
+
+        match response {
+            ApiResponse::Ok(completion) => Ok(completion.try_into()?),
+            ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
+        }
+    }
+}

From ba312a10352931299e0d84656cc5add4adfc9dec Mon Sep 17 00:00:00 2001
From: Garance
Date: Wed, 11 Sep 2024 04:28:46 -0400
Subject: [PATCH 2/3] fix/client/wrong url, add from_env method

---
 rig-core/src/providers/perplexity.rs | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
index 1946b01d..f940da56 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -26,7 +26,7 @@ use serde_json::json;
 // ================================================================
 // Main Perplexity Client
 // ================================================================
-const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai/chat/completions";
+const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
 
 #[derive(Clone)]
 pub struct Client {
@@ -39,6 +39,13 @@ impl Client {
         Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
     }
 
+    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
+        Self::new(&api_key)
+    }
+
     pub fn from_url(api_key: &str, base_url: &str) -> Self {
         Self {
             base_url: base_url.to_string(),

From b2c47fc17a7c7d386c226bbf3cdd94b44842d1fc Mon Sep 17 00:00:00 2001
From: Garance
Date: Wed, 11 Sep 2024 04:35:46 -0400
Subject: [PATCH 3/3] example/improve prompt text

---
 rig-core/examples/perplexity_agent.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rig-core/examples/perplexity_agent.rs b/rig-core/examples/perplexity_agent.rs
index 6f901f24..a8ae8d96 100644
--- a/rig-core/examples/perplexity_agent.rs
+++ b/rig-core/examples/perplexity_agent.rs
@@ -26,7 +26,7 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Prompt the agent and print the response
     let response = agent
-        .prompt("When and where is the next solar eclipse?")
+        .prompt("When and where and what type is the next solar eclipse?")
         .await?;
     println!("{}", response);
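
Usage sketch (illustrative addendum, not part of the patch series): a minimal program that wires the new provider together through the `Client::from_env` constructor added in PATCH 2/3, instead of reading `PERPLEXITY_API_KEY` by hand as the bundled example does. The model constant and the prompt text are placeholders; everything else follows the API introduced in the patches above.

use rig::{completion::Prompt, providers::perplexity};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Reads PERPLEXITY_API_KEY from the environment (panics if it is unset).
    let client = perplexity::Client::from_env();

    // Build a simple agent on one of the online Sonar models.
    let agent = client
        .agent(perplexity::LLAMA_3_1_SONAR_SMALL_ONLINE)
        .preamble("Be precise and concise.")
        .build();

    // Send a single prompt and print the completion text.
    let answer = agent.prompt("How far is the Moon from Earth?").await?;
    println!("{answer}");

    Ok(())
}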