Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support moonshot language model #223

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,7 @@ required-features = ["derive"]
[[example]]
name = "xai_embeddings"
required-features = ["derive"]

[[example]]
name = "agent_with_moonshot"
required-features = ["derive"]
52 changes: 52 additions & 0 deletions rig-core/examples/agent_with_moonshot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use rig::agent::AgentBuilder;
use rig::providers::moonshot::{CompletionModel, MOONSHOT_CHAT};
use rig::{completion::Prompt, providers};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
println!("Running basic agent with moonshot");
basic_moonshot().await?;

println!("\nRunning moonshot agent with context");
context_moonshot().await?;

println!("\n\nAll agents ran successfully");
Ok(())
}

fn client() -> providers::moonshot::Client {
providers::moonshot::Client::from_env()
}

fn partial_agent_moonshot() -> AgentBuilder<CompletionModel> {
let client = client();
client.agent(MOONSHOT_CHAT)
}

async fn basic_moonshot() -> Result<(), anyhow::Error> {
let comedian_agent = partial_agent_moonshot()
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

// Prompt the agent and print the response
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);

Ok(())
}

async fn context_moonshot() -> Result<(), anyhow::Error> {
let model = client().completion_model(MOONSHOT_CHAT);

// Create an agent with multiple context documents
let agent = AgentBuilder::new(model)
.preamble("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.build();

// Prompt the agent and print the response
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;

println!("{}", response);

Ok(())
}
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
pub mod cohere;
pub mod eternalai;
pub mod gemini;
pub mod openai;

Check warning on line 49 in rig-core/src/providers/mod.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/src/providers/mod.rs
pub mod perplexity;
pub mod xai;
pub mod moonshot;
threewebcode marked this conversation as resolved.
Show resolved Hide resolved
265 changes: 265 additions & 0 deletions rig-core/src/providers/moonshot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
//! Moonshot API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::moonshot;
//!
//! let client = moonshot::Client::new("YOUR_API_KEY");
//!
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT);
//! ```

use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::time::Duration;

// ================================================================
// Main Moonshot Client
// ================================================================
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1";

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
/// Create a new Moonshot client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, MOONSHOT_API_BASE_URL)
}

/// Create a new Moonshot client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.timeout(Duration::from_secs(120))

Check failure on line 55 in rig-core/src/providers/moonshot.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

no method named `timeout` found for struct `reqwest::ClientBuilder` in the current scope
threewebcode marked this conversation as resolved.
Show resolved Hide resolved
.build()
.expect("Moonshot reqwest client should build"),
}
}

/// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set");
Self::new(&api_key)
}

fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}

/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::moonshot::{Client, self};
///
/// // Initialize the Moonshot client
/// let moonshot = Client::new("your-moonshot-api-key");
///
/// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
error: MoonshotError,
}

#[derive(Debug, Deserialize)]
struct MoonshotError {
message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}

// ================================================================
// Moonshot Completion API
// ================================================================
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}

impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.error.message)
}
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;

fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
match value.choices.as_slice() {
[Choice {
message:
Message {
content: Some(content),
..
},
..
}, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::Message(content.to_string()),
raw_response: value,
}),
_ => Err(CompletionError::ResponseError(
"Response did not contain a message".into(),
)),
}
}
}

#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}

#[derive(Debug, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub total_tokens: usize,
}

impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}

#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub model: String,
}

impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}

impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;

#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
threewebcode marked this conversation as resolved.
Show resolved Hide resolved
&self,
mut completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
} else {
vec![]
};

full_history.append(&mut completion_request.chat_history);

full_history.push(completion::Message {
role: "user".into(),
content: completion_request.prompt_with_context(),
});

let request = json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
});

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.send()
.await?;

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"Moonshot completion token usage: {}",
response.usage
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
Loading