-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from pepperoni21/chat-feature
Added support for /chat endpoint
- Loading branch information
Showing
10 changed files
with
324 additions
and
18 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use ollama_rs::{ | ||
generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream}, | ||
Ollama, | ||
}; | ||
use tokio::io::{stdout, AsyncWriteExt}; | ||
use tokio_stream::StreamExt; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
let ollama = Ollama::default(); | ||
|
||
let mut stdout = stdout(); | ||
|
||
let mut messages: Vec<ChatMessage> = vec![]; | ||
|
||
loop { | ||
stdout.write_all(b"\n> ").await?; | ||
stdout.flush().await?; | ||
|
||
let mut input = String::new(); | ||
std::io::stdin().read_line(&mut input)?; | ||
|
||
let input = input.trim_end(); | ||
if input.eq_ignore_ascii_case("exit") { | ||
break; | ||
} | ||
|
||
let user_message = ChatMessage::user(input.to_string()); | ||
messages.push(user_message); | ||
|
||
let mut stream: ChatMessageResponseStream = ollama | ||
.send_chat_messages_stream(ChatMessageRequest::new( | ||
"llama2:latest".to_string(), | ||
messages.clone(), | ||
)) | ||
.await?; | ||
|
||
let mut response = String::new(); | ||
while let Some(Ok(res)) = stream.next().await { | ||
if let Some(assistant_message) = res.message { | ||
stdout | ||
.write_all(assistant_message.content.as_bytes()) | ||
.await?; | ||
stdout.flush().await?; | ||
response += assistant_message.content.as_str(); | ||
} | ||
} | ||
messages.push(ChatMessage::assistant(response)); | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
pub mod chat; | ||
pub mod completion; | ||
pub mod embeddings; | ||
pub mod format; | ||
pub mod options; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::Ollama; | ||
|
||
pub mod request; | ||
|
||
use request::ChatMessageRequest; | ||
|
||
#[cfg(feature = "stream")] | ||
/// A stream of `ChatMessageResponse` objects | ||
pub type ChatMessageResponseStream = | ||
std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<ChatMessageResponse, ()>>>>; | ||
|
||
impl Ollama { | ||
#[cfg(feature = "stream")] | ||
/// Chat message generation with streaming. | ||
/// Returns a stream of `ChatMessageResponse` objects | ||
pub async fn send_chat_messages_stream( | ||
&self, | ||
request: ChatMessageRequest, | ||
) -> crate::error::Result<ChatMessageResponseStream> { | ||
use tokio_stream::StreamExt; | ||
|
||
let mut request = request; | ||
request.stream = true; | ||
|
||
let uri = format!("{}/api/chat", self.uri()); | ||
let serialized = serde_json::to_string(&request) | ||
.map_err(|e| e.to_string()) | ||
.unwrap(); | ||
let res = self | ||
.reqwest_client | ||
.post(uri) | ||
.body(serialized) | ||
.send() | ||
.await | ||
.map_err(|e| e.to_string())?; | ||
|
||
if !res.status().is_success() { | ||
return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into()); | ||
} | ||
|
||
let stream = Box::new(res.bytes_stream().map(|res| match res { | ||
Ok(bytes) => { | ||
let res = serde_json::from_slice::<ChatMessageResponse>(&bytes); | ||
match res { | ||
Ok(res) => Ok(res), | ||
Err(e) => { | ||
eprintln!("Failed to deserialize response: {}", e); | ||
Err(()) | ||
} | ||
} | ||
} | ||
Err(e) => { | ||
eprintln!("Failed to read response: {}", e); | ||
Err(()) | ||
} | ||
})); | ||
|
||
Ok(std::pin::Pin::from(stream)) | ||
} | ||
|
||
/// Chat message generation. | ||
/// Returns a `ChatMessageResponse` object | ||
pub async fn send_chat_messages( | ||
&self, | ||
request: ChatMessageRequest, | ||
) -> crate::error::Result<ChatMessageResponse> { | ||
let mut request = request; | ||
request.stream = false; | ||
|
||
let uri = format!("{}/api/chat", self.uri()); | ||
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; | ||
let res = self | ||
.reqwest_client | ||
.post(uri) | ||
.body(serialized) | ||
.send() | ||
.await | ||
.map_err(|e| e.to_string())?; | ||
|
||
if !res.status().is_success() { | ||
return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into()); | ||
} | ||
|
||
let bytes = res.bytes().await.map_err(|e| e.to_string())?; | ||
let res = | ||
serde_json::from_slice::<ChatMessageResponse>(&bytes).map_err(|e| e.to_string())?; | ||
|
||
Ok(res) | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatMessageResponse { | ||
/// The name of the model used for the completion. | ||
pub model: String, | ||
/// The creation time of the completion, in such format: `2023-08-04T08:52:19.385406455-07:00`. | ||
pub created_at: String, | ||
/// The generated chat message. | ||
pub message: Option<ChatMessage>, | ||
pub done: bool, | ||
#[serde(flatten)] | ||
/// The final data of the completion. This is only present if the completion is done. | ||
pub final_data: Option<ChatMessageFinalResponseData>, | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatMessageFinalResponseData { | ||
/// Time spent generating the response | ||
pub total_duration: u64, | ||
/// Number of tokens in the prompt | ||
pub prompt_eval_count: u16, | ||
/// Time spent in nanoseconds evaluating the prompt | ||
pub prompt_eval_duration: u64, | ||
/// Number of tokens the response | ||
pub eval_count: u16, | ||
/// Time in nanoseconds spent generating the response | ||
pub eval_duration: u64, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct ChatMessage { | ||
pub role: MessageRole, | ||
pub content: String, | ||
} | ||
|
||
impl ChatMessage { | ||
pub fn new(role: MessageRole, content: String) -> Self { | ||
Self { role, content } | ||
} | ||
|
||
pub fn user(content: String) -> Self { | ||
Self::new(MessageRole::User, content) | ||
} | ||
|
||
pub fn assistant(content: String) -> Self { | ||
Self::new(MessageRole::Assistant, content) | ||
} | ||
|
||
pub fn system(content: String) -> Self { | ||
Self::new(MessageRole::System, content) | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub enum MessageRole { | ||
#[serde(rename = "user")] | ||
User, | ||
#[serde(rename = "assistant")] | ||
Assistant, | ||
#[serde(rename = "system")] | ||
System, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use serde::Serialize; | ||
|
||
use crate::generation::{format::FormatType, options::GenerationOptions}; | ||
|
||
use super::ChatMessage; | ||
|
||
/// A chat message request to Ollama. | ||
#[derive(Debug, Clone, Serialize)] | ||
pub struct ChatMessageRequest { | ||
#[serde(rename = "model")] | ||
pub model_name: String, | ||
pub messages: Vec<ChatMessage>, | ||
pub options: Option<GenerationOptions>, | ||
pub template: Option<String>, | ||
pub format: Option<FormatType>, | ||
pub(crate) stream: bool, | ||
} | ||
|
||
impl ChatMessageRequest { | ||
pub fn new(model_name: String, messages: Vec<ChatMessage>) -> Self { | ||
Self { | ||
model_name, | ||
messages, | ||
options: None, | ||
template: None, | ||
format: None, | ||
// Stream value will be overwritten by Ollama::send_chat_messages_stream() and Ollama::send_chat_messages() methods | ||
stream: false, | ||
} | ||
} | ||
|
||
/// Additional model parameters listed in the documentation for the Modelfile | ||
pub fn options(mut self, options: GenerationOptions) -> Self { | ||
self.options = Some(options); | ||
self | ||
} | ||
|
||
/// The full prompt or prompt template (overrides what is defined in the Modelfile) | ||
pub fn template(mut self, template: String) -> Self { | ||
self.template = Some(template); | ||
self | ||
} | ||
|
||
// The format to return a response in. Currently the only accepted value is `json` | ||
pub fn format(mut self, format: FormatType) -> Self { | ||
self.format = Some(format); | ||
self | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// The format to return a response in. Currently the only accepted value is `json` | ||
#[derive(Debug, Serialize, Deserialize, Clone)] | ||
#[serde(rename_all = "lowercase")] | ||
pub enum FormatType { | ||
Json, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.