Skip to content

Commit

Permalink
Merge pull request #55 from izyuumi/master
Browse files Browse the repository at this point in the history
Add streaming feature with chat history
  • Loading branch information
pepperoni21 authored Jun 26, 2024
2 parents 5f10610 + 5b9fef4 commit 953232a
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 0 deletions.
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ log = "0.4"
scraper = { version = "0.19.0", optional = true }
text-splitter = { version = "0.13.1", optional = true }
regex = { version = "1.9.3", optional = true }
async-stream = "0.3.5"

[features]
default = ["reqwest/default-tls"]
Expand Down
57 changes: 57 additions & 0 deletions examples/chat_with_history_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage},
Ollama,
};
use tokio::io::{stdout, AsyncWriteExt};
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 30 = max messages kept per conversation in the async history.
    let mut ollama = Ollama::new_default_with_history_async(30);

    let mut stdout = stdout();

    // Single conversation id reused for every request below.
    let chat_id = "default".to_string();

    loop {
        stdout.write_all(b"\n> ").await?;
        stdout.flush().await?;

        let mut input = String::new();
        std::io::stdin().read_line(&mut input)?;

        let input = input.trim_end();
        if input.eq_ignore_ascii_case("exit") {
            break;
        }

        let user_message = ChatMessage::user(input.to_string());

        let mut stream = ollama
            .send_chat_messages_with_history_stream(
                ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]),
                chat_id.clone(),
            )
            .await?;

        // Print each streamed chunk as it arrives. (The previous version also
        // accumulated the chunks into a local `response` String that was never
        // read afterwards — dropped here as dead work.)
        while let Some(Ok(res)) = stream.next().await {
            if let Some(assistant_message) = res.message {
                stdout
                    .write_all(assistant_message.content.as_bytes())
                    .await?;
                stdout.flush().await?;
            }
        }
    }

    // Display the whole history of messages for this chat.
    // Reuse `chat_id` instead of re-spelling the "default" literal so the two
    // can never drift apart.
    dbg!(&ollama.get_messages_history_async(chat_id).await);

    Ok(())
}
92 changes: 92 additions & 0 deletions src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(all(feature = "chat-history", feature = "stream"))]
use async_stream::stream;
use serde::{Deserialize, Serialize};

use crate::Ollama;
Expand All @@ -7,6 +9,8 @@ use request::ChatMessageRequest;

#[cfg(feature = "chat-history")]
use crate::history::MessagesHistory;
#[cfg(all(feature = "chat-history", feature = "stream"))]
use crate::history_async::MessagesHistoryAsync;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
Expand Down Expand Up @@ -151,6 +155,94 @@ impl Ollama {
}
}

#[cfg(all(feature = "chat-history", feature = "stream"))]
impl Ollama {
    /// Returns a clone of the stored chat history for `id`, creating an empty
    /// entry if none exists yet.
    ///
    /// Cloning lets us release the lock immediately and avoids mutating the
    /// stored history before the request is known to succeed.
    async fn get_chat_messages_by_id_async(&mut self, id: String) -> Vec<ChatMessage> {
        match self.messages_history_async.as_mut() {
            Some(history) => history
                .messages_by_id
                .lock()
                .await
                .entry(id)
                .or_default()
                .clone(),
            // No async history configured: behave as an empty conversation.
            // (The previous code built, locked, and mutated a throwaway
            // `MessagesHistoryAsync::default()` just to produce this Vec.)
            None => Vec::new(),
        }
    }

    /// Appends `message` to the history entry `id`; no-op when the async
    /// history is not configured.
    pub async fn store_chat_message_by_id_async(&mut self, id: String, message: ChatMessage) {
        if let Some(messages_history_async) = self.messages_history_async.as_mut() {
            messages_history_async.add_message(id, message).await;
        }
    }

    /// Sends a chat request whose messages are prefixed with the stored
    /// history for `id` and returns the response stream.
    ///
    /// The caller's (first) message is stored once the request is accepted,
    /// and the concatenated assistant reply is stored by a background task
    /// when the stream finishes.
    pub async fn send_chat_messages_with_history_stream(
        &mut self,
        mut request: ChatMessageRequest,
        id: String,
    ) -> crate::error::Result<ChatMessageResponseStream> {
        use tokio_stream::StreamExt;

        // Channel mirroring streamed chunks into the background task that
        // accumulates the assistant's full reply for the history.
        let (tx, mut rx) =
            tokio::sync::mpsc::unbounded_channel::<Result<ChatMessageResponse, ()>>();

        let mut current_chat_messages = self.get_chat_messages_by_id_async(id.clone()).await;

        // Keep the caller's new message so it can be persisted after the
        // request is accepted. (Bug fix: previously only the assistant reply
        // was stored, so every user turn was lost from the history.)
        let user_message = request.messages.first().cloned();
        if let Some(message) = user_message.clone() {
            current_chat_messages.push(message);
        }

        request.messages.clone_from(&current_chat_messages);

        let mut stream = self.send_chat_messages_stream(request.clone()).await?;

        // The request was accepted — record the user's turn now, before any
        // assistant chunks arrive, to keep history in chronological order.
        if let Some(message) = user_message {
            self.store_chat_message_by_id_async(id.clone(), message).await;
        }

        let message_history_async = self.messages_history_async.clone();

        tokio::spawn(async move {
            // Accumulate the assistant reply from the mirrored chunks; an
            // Err chunk aborts accumulation (partial reply is still stored).
            let mut result = String::new();
            while let Some(res) = rx.recv().await {
                match res {
                    Ok(res) => {
                        if let Some(message) = res.message.clone() {
                            result += message.content.as_str();
                        }
                    }
                    Err(_) => {
                        break;
                    }
                }
            }

            if let Some(message_history_async) = message_history_async {
                message_history_async
                    .add_message(id.clone(), ChatMessage::assistant(result))
                    .await;
            } else {
                eprintln!("not using chat-history and stream features"); // this should not happen if the features are enabled
            }
        });

        // Forward every chunk to the caller while tee-ing it into `tx` for
        // the accumulator task above.
        let s = stream! {
            while let Some(res) = stream.next().await {
                match res {
                    Ok(res) => {
                        if let Err(e) = tx.send(Ok(res.clone())) {
                            eprintln!("Failed to send response: {}", e);
                        };
                        yield Ok(res);
                    }
                    Err(_) => {
                        yield Err(());
                    }
                }
            }
        };

        Ok(Box::pin(s))
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageResponse {
/// The name of the model used for the completion.
Expand Down
141 changes: 141 additions & 0 deletions src/history_async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::{
generation::chat::{ChatMessage, MessageRole},
Ollama,
};

/// Chat history store keyed by conversation id, shareable across async
/// tasks: the map lives behind an `Arc<Mutex<..>>` (tokio mutex), so clones
/// of this struct all observe the same underlying history.
#[derive(Debug, Clone, Default)]
pub struct MessagesHistoryAsync {
    /// Conversation id -> ordered list of messages for that conversation.
    pub(crate) messages_by_id: Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>,
    /// Max messages kept per conversation. `new` clamps this to >= 2; note
    /// that `Default` leaves it at 0 — presumably only used for read paths,
    /// TODO confirm no caller adds messages through a defaulted instance.
    pub(crate) messages_number_limit: u16,
}

impl MessagesHistoryAsync {
    /// Builds a history store capped at `messages_number_limit` messages per
    /// conversation; the cap is clamped to a minimum of 2.
    pub fn new(messages_number_limit: u16) -> Self {
        Self {
            messages_by_id: Arc::new(Mutex::new(HashMap::new())),
            messages_number_limit: messages_number_limit.max(2),
        }
    }

    /// Appends `message` to the conversation `entry_id`, evicting the oldest
    /// message once the cap is reached. A system prompt sitting at slot 0
    /// survives eviction, and new system messages are pinned to the front.
    pub async fn add_message(&self, entry_id: String, message: ChatMessage) {
        let mut table = self.messages_by_id.lock().await;
        let history = table.entry(entry_id).or_default();

        // At capacity: drop the oldest message, skipping over a leading
        // system prompt so it is never evicted. (Cap >= 2 guarantees the
        // index is in bounds whenever this branch runs.)
        if history.len() >= self.messages_number_limit as usize {
            let evict_at = match history.first() {
                Some(first) if first.role == MessageRole::System => 1,
                _ => 0,
            };
            history.remove(evict_at);
        }

        // System messages always lead the conversation; everything else is
        // appended in arrival order.
        if message.role == MessageRole::System {
            history.insert(0, message);
        } else {
            history.push(message);
        }
    }

    /// Returns a snapshot of the messages stored for `entry_id`, if any.
    pub async fn get_messages(&self, entry_id: &str) -> Option<Vec<ChatMessage>> {
        self.messages_by_id.lock().await.get(entry_id).cloned()
    }

    /// Removes the whole conversation `entry_id` from the store.
    pub async fn clear_messages(&self, entry_id: &str) {
        self.messages_by_id.lock().await.remove(entry_id);
    }
}

impl Ollama {
    /// Create default instance with async chat history
    pub fn new_default_with_history_async(messages_number_limit: u16) -> Self {
        Self {
            messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)),
            ..Default::default()
        }
    }

    /// Create new instance with async chat history
    ///
    /// # Panics
    ///
    /// Panics if the host is not a valid URL or if the URL cannot have a port.
    pub fn new_with_history_async(
        host: impl crate::IntoUrl,
        port: u16,
        messages_number_limit: u16,
    ) -> Self {
        let mut url = host.into_url().unwrap();
        url.set_port(Some(port)).unwrap();
        // Bug fix: this previously delegated to `new_with_history_from_url`
        // (the synchronous-history constructor), which leaves
        // `messages_history_async` unset — the async variant is intended.
        Self::new_with_history_from_url_async(url, messages_number_limit)
    }

    /// Create new instance with async chat history from a [`url::Url`].
    #[inline]
    pub fn new_with_history_from_url_async(url: url::Url, messages_number_limit: u16) -> Self {
        Self {
            url,
            messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)),
            ..Default::default()
        }
    }

    /// Fallible counterpart of [`Ollama::new_with_history_from_url_async`]:
    /// returns `Err` instead of panicking on an invalid URL.
    #[inline]
    pub fn try_new_with_history_async(
        url: impl crate::IntoUrl,
        messages_number_limit: u16,
    ) -> Result<Self, url::ParseError> {
        Ok(Self {
            url: url.into_url()?,
            messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)),
            ..Default::default()
        })
    }

    /// Add AI's message to a history (no-op when async history is disabled)
    pub async fn add_assistant_response_async(&mut self, entry_id: String, message: String) {
        if let Some(messages_history) = self.messages_history_async.as_mut() {
            messages_history
                .add_message(entry_id, ChatMessage::assistant(message))
                .await;
        }
    }

    /// Add user's message to a history (no-op when async history is disabled)
    pub async fn add_user_response_async(&mut self, entry_id: String, message: String) {
        if let Some(messages_history) = self.messages_history_async.as_mut() {
            messages_history
                .add_message(entry_id, ChatMessage::user(message))
                .await;
        }
    }

    /// Set system prompt for chat history (stored at the front of the
    /// conversation by `MessagesHistoryAsync::add_message`)
    pub async fn set_system_response_async(&mut self, entry_id: String, message: String) {
        if let Some(messages_history) = self.messages_history_async.as_mut() {
            messages_history
                .add_message(entry_id, ChatMessage::system(message))
                .await;
        }
    }

    /// For tests purpose
    /// Getting list of messages in a history; `None` when async history is
    /// disabled or the entry does not exist
    pub async fn get_messages_history_async(
        &mut self,
        entry_id: String,
    ) -> Option<Vec<ChatMessage>> {
        if let Some(messages_history_async) = self.messages_history_async.as_mut() {
            messages_history_async.get_messages(&entry_id).await
        } else {
            None
        }
    }
}
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ pub mod error;
pub mod generation;
#[cfg(feature = "chat-history")]
pub mod history;
#[cfg(all(feature = "chat-history", feature = "stream"))]
pub mod history_async;
pub mod models;

use url::Url;
Expand Down Expand Up @@ -69,6 +71,8 @@ pub struct Ollama {
pub(crate) reqwest_client: reqwest::Client,
#[cfg(feature = "chat-history")]
pub(crate) messages_history: Option<history::MessagesHistory>,
#[cfg(all(feature = "chat-history", feature = "stream"))]
pub(crate) messages_history_async: Option<history_async::MessagesHistoryAsync>,
}

impl Ollama {
Expand Down Expand Up @@ -145,6 +149,8 @@ impl Default for Ollama {
reqwest_client: reqwest::Client::new(),
#[cfg(feature = "chat-history")]
messages_history: None,
#[cfg(all(feature = "chat-history", feature = "stream"))]
messages_history_async: None,
}
}
}

0 comments on commit 953232a

Please sign in to comment.