diff --git a/examples/chatbot.rs b/examples/basic_chatbot.rs
similarity index 100%
rename from examples/chatbot.rs
rename to examples/basic_chatbot.rs
diff --git a/examples/chat_api_chatbot.rs b/examples/chat_api_chatbot.rs
new file mode 100644
index 0000000..27700d5
--- /dev/null
+++ b/examples/chat_api_chatbot.rs
@@ -0,0 +1,52 @@
+use ollama_rs::{
+    generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
+    Ollama,
+};
+use tokio::io::{stdout, AsyncWriteExt};
+use tokio_stream::StreamExt;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let ollama = Ollama::default();
+
+    let mut stdout = stdout();
+
+    let mut messages: Vec<ChatMessage> = vec![];
+
+    loop {
+        stdout.write_all(b"\n> ").await?;
+        stdout.flush().await?;
+
+        let mut input = String::new();
+        std::io::stdin().read_line(&mut input)?;
+
+        let input = input.trim_end();
+        if input.eq_ignore_ascii_case("exit") {
+            break;
+        }
+
+        let user_message = ChatMessage::user(input.to_string());
+        messages.push(user_message);
+
+        let mut stream: ChatMessageResponseStream = ollama
+            .send_chat_messages_stream(ChatMessageRequest::new(
+                "llama2:latest".to_string(),
+                messages.clone(),
+            ))
+            .await?;
+
+        let mut response = String::new();
+        while let Some(Ok(res)) = stream.next().await {
+            if let Some(assistant_message) = res.message {
+                stdout
+                    .write_all(assistant_message.content.as_bytes())
+                    .await?;
+                stdout.flush().await?;
+                response += assistant_message.content.as_str();
+            }
+        }
+        messages.push(ChatMessage::assistant(response));
+    }
+
+    Ok(())
+}
diff --git a/src/generation.rs b/src/generation.rs
index 7442e54..9e28528 100644
--- a/src/generation.rs
+++ b/src/generation.rs
@@ -1,3 +1,5 @@
+pub mod chat;
 pub mod completion;
 pub mod embeddings;
+pub mod format;
 pub mod options;
diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs
new file mode 100644
index 0000000..362aa74
--- /dev/null
+++ b/src/generation/chat/mod.rs
@@ -0,0 +1,154 @@
+use serde::{Deserialize, Serialize};
+
+use crate::Ollama;
+
+pub mod request;
+
+use request::ChatMessageRequest;
+
+#[cfg(feature = "stream")]
+/// A stream of `ChatMessageResponse` objects
+pub type ChatMessageResponseStream =
+    std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<ChatMessageResponse, ()>>>>;
+
+impl Ollama {
+    #[cfg(feature = "stream")]
+    /// Chat message generation with streaming.
+    /// Returns a stream of `ChatMessageResponse` objects
+    pub async fn send_chat_messages_stream(
+        &self,
+        request: ChatMessageRequest,
+    ) -> crate::error::Result<ChatMessageResponseStream> {
+        use tokio_stream::StreamExt;
+
+        let mut request = request;
+        request.stream = true;
+
+        let uri = format!("{}/api/chat", self.uri());
+        let serialized = serde_json::to_string(&request)
+            .map_err(|e| e.to_string())
+            .unwrap();
+        let res = self
+            .reqwest_client
+            .post(uri)
+            .body(serialized)
+            .send()
+            .await
+            .map_err(|e| e.to_string())?;
+
+        if !res.status().is_success() {
+            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
+        }
+
+        let stream = Box::new(res.bytes_stream().map(|res| match res {
+            Ok(bytes) => {
+                let res = serde_json::from_slice::<ChatMessageResponse>(&bytes);
+                match res {
+                    Ok(res) => Ok(res),
+                    Err(e) => {
+                        eprintln!("Failed to deserialize response: {}", e);
+                        Err(())
+                    }
+                }
+            }
+            Err(e) => {
+                eprintln!("Failed to read response: {}", e);
+                Err(())
+            }
+        }));
+
+        Ok(std::pin::Pin::from(stream))
+    }
+
+    /// Chat message generation.
+    /// Returns a `ChatMessageResponse` object
+    pub async fn send_chat_messages(
+        &self,
+        request: ChatMessageRequest,
+    ) -> crate::error::Result<ChatMessageResponse> {
+        let mut request = request;
+        request.stream = false;
+
+        let uri = format!("{}/api/chat", self.uri());
+        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
+        let res = self
+            .reqwest_client
+            .post(uri)
+            .body(serialized)
+            .send()
+            .await
+            .map_err(|e| e.to_string())?;
+
+        if !res.status().is_success() {
+            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
+        }
+
+        let bytes = res.bytes().await.map_err(|e| e.to_string())?;
+        let res =
+            serde_json::from_slice::<ChatMessageResponse>(&bytes).map_err(|e| e.to_string())?;
+
+        Ok(res)
+    }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatMessageResponse {
+    /// The name of the model used for the completion.
+    pub model: String,
+    /// The creation time of the completion, in the format `2023-08-04T08:52:19.385406455-07:00`.
+    pub created_at: String,
+    /// The generated chat message.
+    pub message: Option<ChatMessage>,
+    pub done: bool,
+    #[serde(flatten)]
+    /// The final data of the completion. This is only present if the completion is done.
+    pub final_data: Option<ChatMessageFinalResponseData>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatMessageFinalResponseData {
+    /// Time spent generating the response
+    pub total_duration: u64,
+    /// Number of tokens in the prompt
+    pub prompt_eval_count: u16,
+    /// Time spent in nanoseconds evaluating the prompt
+    pub prompt_eval_duration: u64,
+    /// Number of tokens in the response
+    pub eval_count: u16,
+    /// Time in nanoseconds spent generating the response
+    pub eval_duration: u64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: MessageRole,
+    pub content: String,
+}
+
+impl ChatMessage {
+    pub fn new(role: MessageRole, content: String) -> Self {
+        Self { role, content }
+    }
+
+    pub fn user(content: String) -> Self {
+        Self::new(MessageRole::User, content)
+    }
+
+    pub fn assistant(content: String) -> Self {
+        Self::new(MessageRole::Assistant, content)
+    }
+
+    pub fn system(content: String) -> Self {
+        Self::new(MessageRole::System, content)
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MessageRole {
+    #[serde(rename = "user")]
+    User,
+    #[serde(rename = "assistant")]
+    Assistant,
+    #[serde(rename = "system")]
+    System,
+}
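
For reference, a minimal non-streaming call against the new chat endpoint could look like the sketch below. This is illustrative only, not part of the change: it assumes a local Ollama instance with `llama2:latest` pulled, as the tests in this change do.

```rust
use ollama_rs::{
    generation::chat::{request::ChatMessageRequest, ChatMessage},
    Ollama,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ollama = Ollama::default();

    // Single-turn conversation: one user message, no prior history.
    let messages = vec![ChatMessage::user("Why is the sky blue?".to_string())];

    let res = ollama
        .send_chat_messages(ChatMessageRequest::new(
            "llama2:latest".to_string(),
            messages,
        ))
        .await?;

    // `message` is an Option; it holds the assistant's reply when present.
    if let Some(message) = res.message {
        println!("{}", message.content);
    }
    Ok(())
}
```
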
diff --git a/src/generation/chat/request.rs b/src/generation/chat/request.rs
new file mode 100644
index 0000000..bbc9d8b
--- /dev/null
+++ b/src/generation/chat/request.rs
@@ -0,0 +1,49 @@
+use serde::Serialize;
+
+use crate::generation::{format::FormatType, options::GenerationOptions};
+
+use super::ChatMessage;
+
+/// A chat message request to Ollama.
+#[derive(Debug, Clone, Serialize)]
+pub struct ChatMessageRequest {
+    #[serde(rename = "model")]
+    pub model_name: String,
+    pub messages: Vec<ChatMessage>,
+    pub options: Option<GenerationOptions>,
+    pub template: Option<String>,
+    pub format: Option<FormatType>,
+    pub(crate) stream: bool,
+}
+
+impl ChatMessageRequest {
+    pub fn new(model_name: String, messages: Vec<ChatMessage>) -> Self {
+        Self {
+            model_name,
+            messages,
+            options: None,
+            template: None,
+            format: None,
+            // The stream value is overwritten by the Ollama::send_chat_messages_stream() and Ollama::send_chat_messages() methods
+            stream: false,
+        }
+    }
+
+    /// Additional model parameters listed in the documentation for the Modelfile
+    pub fn options(mut self, options: GenerationOptions) -> Self {
+        self.options = Some(options);
+        self
+    }
+
+    /// The full prompt or prompt template (overrides what is defined in the Modelfile)
+    pub fn template(mut self, template: String) -> Self {
+        self.template = Some(template);
+        self
+    }
+
+    /// The format to return a response in. Currently the only accepted value is `json`
+    pub fn format(mut self, format: FormatType) -> Self {
+        self.format = Some(format);
+        self
+    }
+}
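
A quick sketch of the builder methods above, serializing the result only to show the wire shape (`serde_json` is already a dependency of the crate; the model name and prompt are placeholders):

```rust
use ollama_rs::generation::{
    chat::{request::ChatMessageRequest, ChatMessage},
    format::FormatType,
};

fn main() {
    let request = ChatMessageRequest::new(
        "llama2:latest".to_string(),
        vec![ChatMessage::user("List three primary colors.".to_string())],
    )
    .format(FormatType::Json);

    // `stream` is pub(crate) and defaults to false here; the send_* methods
    // set it themselves before the request goes out.
    println!("{}", serde_json::to_string(&request).unwrap());
}
```
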
diff --git a/src/generation/completion/mod.rs b/src/generation/completion/mod.rs
index bb39cce..74445f3 100644
--- a/src/generation/completion/mod.rs
+++ b/src/generation/completion/mod.rs
@@ -110,8 +110,6 @@ pub struct GenerationFinalResponseData {
     pub context: GenerationContext,
     /// Time spent generating the response
     pub total_duration: u64,
-    /// Time spent in nanoseconds loading the model
-    pub load_duration: u64,
     /// Number of tokens in the prompt
     pub prompt_eval_count: u16,
     /// Time spent in nanoseconds evaluating the prompt
diff --git a/src/generation/completion/request.rs b/src/generation/completion/request.rs
index b7c5115..1f23e97 100644
--- a/src/generation/completion/request.rs
+++ b/src/generation/completion/request.rs
@@ -1,6 +1,6 @@
-use serde::{Deserialize, Serialize};
+use serde::Serialize;

-use crate::generation::options::GenerationOptions;
+use crate::generation::{format::FormatType, options::GenerationOptions};

 use super::GenerationContext;
@@ -14,17 +14,10 @@ pub struct GenerationRequest {
     pub system: Option<String>,
     pub template: Option<String>,
     pub context: Option<GenerationContext>,
-    pub format: Option<FormatEnum>,
+    pub format: Option<FormatType>,
     pub(crate) stream: bool,
 }

-/// The format to return a response in. Currently the only accepted value is `json`
-#[derive(Debug, Serialize, Deserialize, Clone)]
-#[serde(rename_all = "lowercase")]
-pub enum FormatEnum {
-    Json,
-}
-
 impl GenerationRequest {
     pub fn new(model_name: String, prompt: String) -> Self {
         Self {
@@ -65,7 +58,7 @@ impl GenerationRequest {
     }

     // The format to return a response in. Currently the only accepted value is `json`
-    pub fn format(mut self, format: FormatEnum) -> Self {
+    pub fn format(mut self, format: FormatType) -> Self {
         self.format = Some(format);
         self
     }
diff --git a/src/generation/format.rs b/src/generation/format.rs
new file mode 100644
index 0000000..65bcb09
--- /dev/null
+++ b/src/generation/format.rs
@@ -0,0 +1,8 @@
+use serde::{Deserialize, Serialize};
+
+/// The format to return a response in. Currently the only accepted value is `json`
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum FormatType {
+    Json,
+}
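
With `FormatEnum` removed, existing completion callers only need to switch to the shared `FormatType`. Roughly (a sketch; the prompt is a placeholder):

```rust
use ollama_rs::generation::{completion::request::GenerationRequest, format::FormatType};

fn main() {
    // Before this change: use ...::completion::request::{FormatEnum, GenerationRequest};
    let _request = GenerationRequest::new(
        "llama2:latest".to_string(),
        "Describe the sky as a JSON object.".to_string(),
    )
    .format(FormatType::Json);
}
```
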
diff --git a/tests/generation.rs b/tests/generation.rs
index e9f8430..52ad195 100644
--- a/tests/generation.rs
+++ b/tests/generation.rs
@@ -1,10 +1,7 @@
 #![allow(unused_imports)]

 use ollama_rs::{
-    generation::completion::{
-        request::{FormatEnum, GenerationRequest},
-        GenerationResponseStream,
-    },
+    generation::completion::{request::GenerationRequest, GenerationResponseStream},
     Ollama,
 };
 use tokio::io::AsyncWriteExt;
@@ -28,6 +25,7 @@ async fn test_generation_stream() {
     let mut done = false;
     while let Some(res) = res.next().await {
         let res = res.unwrap();
+        dbg!(&res);
         if res.done {
             done = true;
             break;
         }
@@ -41,11 +39,12 @@ async fn test_generation() {
     let ollama = Ollama::default();

-    let _ = ollama
+    let res = ollama
         .generate(GenerationRequest::new(
             "llama2:latest".to_string(),
             PROMPT.into(),
         ))
         .await
         .unwrap();
+    dbg!(res);
 }
diff --git a/tests/send_chat_messages.rs b/tests/send_chat_messages.rs
new file mode 100644
index 0000000..c9681a4
--- /dev/null
+++ b/tests/send_chat_messages.rs
@@ -0,0 +1,51 @@
+use ollama_rs::{
+    generation::chat::{request::ChatMessageRequest, ChatMessage},
+    Ollama,
+};
+use tokio_stream::StreamExt;
+
+#[allow(dead_code)]
+const PROMPT: &str = "Why is the sky blue?";
+
+#[tokio::test]
+async fn test_send_chat_messages_stream() {
+    let ollama = Ollama::default();
+
+    let messages = vec![ChatMessage::user(PROMPT.to_string())];
+    let mut res = ollama
+        .send_chat_messages_stream(ChatMessageRequest::new(
+            "llama2:latest".to_string(),
+            messages,
+        ))
+        .await
+        .unwrap();
+
+    let mut done = false;
+    while let Some(res) = res.next().await {
+        let res = res.unwrap();
+        dbg!(&res);
+        if res.done {
+            done = true;
+            break;
+        }
+    }
+
+    assert!(done);
+}
+
+#[tokio::test]
+async fn test_send_chat_messages() {
+    let ollama = Ollama::default();
+
+    let messages = vec![ChatMessage::user(PROMPT.to_string())];
+    let res = ollama
+        .send_chat_messages(ChatMessageRequest::new(
+            "llama2:latest".to_string(),
+            messages,
+        ))
+        .await
+        .unwrap();
+    dbg!(&res);
+
+    assert!(res.done);
+}
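
One thing the tests don't exercise is role handling. A sketch of building a history that starts with a system prompt (placeholder strings throughout):

```rust
use ollama_rs::generation::chat::{request::ChatMessageRequest, ChatMessage, MessageRole};

fn main() {
    let messages = vec![
        // System prompt first, then the user turn; the request carries the full history.
        ChatMessage::system("You answer in one short sentence.".to_string()),
        // ChatMessage::user(...) is shorthand for the explicit constructor:
        ChatMessage::new(MessageRole::User, "Why is the sky blue?".to_string()),
    ];
    let _request = ChatMessageRequest::new("llama2:latest".to_string(), messages);
}
```
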