Skip to content

Commit

Permalink
Merge pull request #7 from pepperoni21/chat-feature
Browse files Browse the repository at this point in the history
Added support for /chat endpoint
  • Loading branch information
pepperoni21 authored Dec 11, 2023
2 parents 0bd7141 + 41c64d3 commit 7d72bda
Show file tree
Hide file tree
Showing 10 changed files with 324 additions and 18 deletions.
File renamed without changes.
52 changes: 52 additions & 0 deletions examples/chat_api_chatbot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
Ollama,
};
use tokio::io::{stdout, AsyncWriteExt};
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Minimal chatbot REPL: keeps the whole conversation and streams each reply.
    let ollama = Ollama::default();

    let mut out = stdout();

    // Running transcript; resent with every request so the model keeps context.
    let mut history: Vec<ChatMessage> = vec![];

    loop {
        // Prompt for the next user message.
        out.write_all(b"\n> ").await?;
        out.flush().await?;

        let mut line = String::new();
        std::io::stdin().read_line(&mut line)?;

        let line = line.trim_end();
        // Typing "exit" (any case) ends the session.
        if line.eq_ignore_ascii_case("exit") {
            break;
        }

        history.push(ChatMessage::user(line.to_string()));

        // Ask for a streamed reply over the full history so far.
        let mut stream: ChatMessageResponseStream = ollama
            .send_chat_messages_stream(ChatMessageRequest::new(
                "llama2:latest".to_string(),
                history.clone(),
            ))
            .await?;

        // Echo each partial chunk as it arrives while accumulating the reply.
        let mut reply = String::new();
        while let Some(Ok(chunk)) = stream.next().await {
            if let Some(partial) = chunk.message {
                out.write_all(partial.content.as_bytes()).await?;
                out.flush().await?;
                reply.push_str(partial.content.as_str());
            }
        }
        // Record the assistant's full reply so the next turn has context.
        history.push(ChatMessage::assistant(reply));
    }

    Ok(())
}
2 changes: 2 additions & 0 deletions src/generation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod chat;
pub mod completion;
pub mod embeddings;
pub mod format;
pub mod options;
154 changes: 154 additions & 0 deletions src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use serde::{Deserialize, Serialize};

use crate::Ollama;

pub mod request;

use request::ChatMessageRequest;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
///
/// Each item is `Ok` with a parsed response chunk, or `Err(())` when a chunk
/// could not be read or deserialized (the underlying error is printed to
/// stderr by the stream adapter rather than carried in the item).
pub type ChatMessageResponseStream =
std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<ChatMessageResponse, ()>>>>;

impl Ollama {
    #[cfg(feature = "stream")]
    /// Chat message generation with streaming.
    /// Returns a stream of `ChatMessageResponse` objects
    ///
    /// # Errors
    /// Fails if the request cannot be serialized, the HTTP call fails, or the
    /// server replies with a non-success status (the response body becomes the
    /// error message). Per-chunk decode failures do NOT fail this call; they
    /// surface as `Err(())` items on the returned stream.
    pub async fn send_chat_messages_stream(
        &self,
        request: ChatMessageRequest,
    ) -> crate::error::Result<ChatMessageResponseStream> {
        use tokio_stream::StreamExt;

        // Force streaming on regardless of what the caller set.
        let mut request = request;
        request.stream = true;

        let uri = format!("{}/api/chat", self.uri());
        // Bug fix: propagate serialization errors instead of panicking.
        // (Was `.map_err(..).unwrap()`, which defeated the map_err and could
        // abort the process; now consistent with send_chat_messages below.)
        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
        let res = self
            .reqwest_client
            .post(uri)
            .body(serialized)
            .send()
            .await
            .map_err(|e| e.to_string())?;

        if !res.status().is_success() {
            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
        }

        // Map each HTTP body chunk to a parsed ChatMessageResponse.
        // NOTE(review): this assumes every chunk holds exactly one complete
        // JSON object — confirm against Ollama's chunked response framing.
        let stream = Box::new(res.bytes_stream().map(|res| match res {
            Ok(bytes) => {
                let res = serde_json::from_slice::<ChatMessageResponse>(&bytes);
                match res {
                    Ok(res) => Ok(res),
                    Err(e) => {
                        eprintln!("Failed to deserialize response: {}", e);
                        Err(())
                    }
                }
            }
            Err(e) => {
                eprintln!("Failed to read response: {}", e);
                Err(())
            }
        }));

        Ok(std::pin::Pin::from(stream))
    }

    /// Chat message generation.
    /// Returns a `ChatMessageResponse` object
    ///
    /// # Errors
    /// Fails if the request cannot be serialized, the HTTP call fails, the
    /// server replies with a non-success status, or the body cannot be
    /// deserialized into a `ChatMessageResponse`.
    pub async fn send_chat_messages(
        &self,
        request: ChatMessageRequest,
    ) -> crate::error::Result<ChatMessageResponse> {
        // Force streaming off so the server returns a single JSON object.
        let mut request = request;
        request.stream = false;

        let uri = format!("{}/api/chat", self.uri());
        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
        let res = self
            .reqwest_client
            .post(uri)
            .body(serialized)
            .send()
            .await
            .map_err(|e| e.to_string())?;

        if !res.status().is_success() {
            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
        }

        let bytes = res.bytes().await.map_err(|e| e.to_string())?;
        let res =
            serde_json::from_slice::<ChatMessageResponse>(&bytes).map_err(|e| e.to_string())?;

        Ok(res)
    }
}

/// A single response object from the `/api/chat` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessageResponse {
    /// The name of the model used for the completion.
    pub model: String,
    /// The creation time of the completion, in such format: `2023-08-04T08:52:19.385406455-07:00`.
    pub created_at: String,
    /// The generated chat message.
    pub message: Option<ChatMessage>,
    /// Whether this is the final response of the completion.
    pub done: bool,
    #[serde(flatten)]
    /// The final data of the completion. This is only present if the completion is done.
    pub final_data: Option<ChatMessageFinalResponseData>,
}

/// Summary statistics attached to the final response of a chat completion.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatMessageFinalResponseData {
    /// Time spent generating the response
    pub total_duration: u64,
    /// Number of tokens in the prompt
    // NOTE(review): u16 caps the count at 65535 — confirm this is enough for
    // long contexts.
    pub prompt_eval_count: u16,
    /// Time spent in nanoseconds evaluating the prompt
    pub prompt_eval_duration: u64,
    /// Number of tokens in the response
    pub eval_count: u16,
    /// Time in nanoseconds spent generating the response
    pub eval_duration: u64,
}

/// A single message in a chat conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message (user, assistant, or system).
    pub role: MessageRole,
    /// The text content of the message.
    pub content: String,
}

impl ChatMessage {
pub fn new(role: MessageRole, content: String) -> Self {
Self { role, content }
}

pub fn user(content: String) -> Self {
Self::new(MessageRole::User, content)
}

pub fn assistant(content: String) -> Self {
Self::new(MessageRole::Assistant, content)
}

pub fn system(content: String) -> Self {
Self::new(MessageRole::System, content)
}
}

/// The author of a chat message.
///
/// Serialized in lowercase (`"user"` / `"assistant"` / `"system"`).
/// Uses `rename_all` instead of per-variant renames, consistent with
/// `FormatType` in `generation::format`; the wire format is unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    User,
    Assistant,
    System,
}
49 changes: 49 additions & 0 deletions src/generation/chat/request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use serde::Serialize;

use crate::generation::{format::FormatType, options::GenerationOptions};

use super::ChatMessage;

/// A chat message request to Ollama.
#[derive(Debug, Clone, Serialize)]
pub struct ChatMessageRequest {
    /// Name of the model to chat with (serialized as `model`).
    #[serde(rename = "model")]
    pub model_name: String,
    /// The conversation history sent to the model.
    pub messages: Vec<ChatMessage>,
    /// Additional model parameters from the Modelfile documentation.
    // NOTE(review): `None` fields serialize as JSON `null` (there is no
    // `skip_serializing_if`) — confirm the server accepts nulls here.
    pub options: Option<GenerationOptions>,
    /// The full prompt template (overrides what is defined in the Modelfile).
    pub template: Option<String>,
    /// The format to return a response in. Currently only `json` is accepted.
    pub format: Option<FormatType>,
    /// Whether to stream the response; set by the `Ollama` send methods.
    pub(crate) stream: bool,
}

impl ChatMessageRequest {
pub fn new(model_name: String, messages: Vec<ChatMessage>) -> Self {
Self {
model_name,
messages,
options: None,
template: None,
format: None,
// Stream value will be overwritten by Ollama::send_chat_messages_stream() and Ollama::send_chat_messages() methods
stream: false,
}
}

/// Additional model parameters listed in the documentation for the Modelfile
pub fn options(mut self, options: GenerationOptions) -> Self {
self.options = Some(options);
self
}

/// The full prompt or prompt template (overrides what is defined in the Modelfile)
pub fn template(mut self, template: String) -> Self {
self.template = Some(template);
self
}

// The format to return a response in. Currently the only accepted value is `json`
pub fn format(mut self, format: FormatType) -> Self {
self.format = Some(format);
self
}
}
2 changes: 0 additions & 2 deletions src/generation/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ pub struct GenerationFinalResponseData {
pub context: GenerationContext,
/// Time spent generating the response
pub total_duration: u64,
/// Time spent in nanoseconds loading the model
pub load_duration: u64,
/// Number of tokens in the prompt
pub prompt_eval_count: u16,
/// Time spent in nanoseconds evaluating the prompt
Expand Down
15 changes: 4 additions & 11 deletions src/generation/completion/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};
use serde::Serialize;

use crate::generation::options::GenerationOptions;
use crate::generation::{format::FormatType, options::GenerationOptions};

use super::GenerationContext;

Expand All @@ -14,17 +14,10 @@ pub struct GenerationRequest {
pub system: Option<String>,
pub template: Option<String>,
pub context: Option<GenerationContext>,
pub format: Option<FormatEnum>,
pub format: Option<FormatType>,
pub(crate) stream: bool,
}

/// The format to return a response in. Currently the only accepted value is `json`
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum FormatEnum {
Json,
}

impl GenerationRequest {
pub fn new(model_name: String, prompt: String) -> Self {
Self {
Expand Down Expand Up @@ -65,7 +58,7 @@ impl GenerationRequest {
}

// The format to return a response in. Currently the only accepted value is `json`
pub fn format(mut self, format: FormatEnum) -> Self {
pub fn format(mut self, format: FormatType) -> Self {
self.format = Some(format);
self
}
Expand Down
8 changes: 8 additions & 0 deletions src/generation/format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};

/// The format to return a response in. Currently the only accepted value is `json`
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
    /// Ask the server for structured JSON output (serialized as `"json"`).
    Json,
}
9 changes: 4 additions & 5 deletions tests/generation.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#![allow(unused_imports)]

use ollama_rs::{
generation::completion::{
request::{FormatEnum, GenerationRequest},
GenerationResponseStream,
},
generation::completion::{request::GenerationRequest, GenerationResponseStream},
Ollama,
};
use tokio::io::AsyncWriteExt;
Expand All @@ -28,6 +25,7 @@ async fn test_generation_stream() {
let mut done = false;
while let Some(res) = res.next().await {
let res = res.unwrap();
dbg!(&res);
if res.done {
done = true;
break;
Expand All @@ -41,11 +39,12 @@ async fn test_generation_stream() {
async fn test_generation() {
let ollama = Ollama::default();

let _ = ollama
let res = ollama
.generate(GenerationRequest::new(
"llama2:latest".to_string(),
PROMPT.into(),
))
.await
.unwrap();
dbg!(res);
}
Loading

0 comments on commit 7d72bda

Please sign in to comment.