Skip to content

Commit

Permalink
Added abort request for generate_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
pepperoni21 committed Jan 13, 2025
1 parent e509fc4 commit 16119f7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
2 changes: 2 additions & 0 deletions ollama-rs/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum OllamaError {
ReqwestError(#[from] reqwest::Error),
#[error("Internal Ollama error")]
InternalError(InternalOllamaError),
#[error("Ollama aborted the request")]
Abort,
#[error("Error in Ollama")]
Other(String),
}
Expand Down
29 changes: 18 additions & 11 deletions ollama-rs/src/generation/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,25 @@ impl Ollama {
));
}

let stream = Box::new(res.bytes_stream().map(|res| match res {
Ok(bytes) => {
let res = serde_json::Deserializer::from_slice(&bytes).into_iter();
let res = res
.filter_map(Result::ok) // Filter out the errors
.collect::<Vec<GenerationResponse>>();
Ok(res)
let stream = Box::new(res.bytes_stream().map(move |res| {
if let Some(abort_signal) = request.abort_signal.as_ref() {
if abort_signal.aborted() {
return Err(OllamaError::Abort);
}
}
match res {
Ok(bytes) => {
let res = serde_json::Deserializer::from_slice(&bytes).into_iter();
let res = res
.filter_map(Result::ok) // Filter out the errors
.collect::<Vec<GenerationResponse>>();
Ok(res)
}
Err(e) => Err(OllamaError::Other(format!(
"Failed to read response: {}",
e
))),
}
Err(e) => Err(OllamaError::Other(format!(
"Failed to read response: {}",
e
))),
}));

Ok(std::pin::Pin::from(stream))
Expand Down
33 changes: 33 additions & 0 deletions ollama-rs/src/generation/completion/request.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::{atomic::AtomicBool, Arc};

use serde::Serialize;

use crate::generation::{
Expand All @@ -8,6 +10,28 @@ use crate::generation::{

use super::GenerationContext;

/// A cloneable handle used to request cancellation of an in-flight request.
///
/// Every clone shares the same underlying flag (the flag lives behind an
/// [`Arc`]), so calling [`AbortSignal::abort`] on any clone makes
/// [`AbortSignal::aborted`] return `true` on all of them. The flag is one-way:
/// there is no method to clear it again.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    pub(crate) abort_signal: Arc<AtomicBool>,
}

impl AbortSignal {
    /// Creates a new signal in the "not aborted" state.
    pub fn new() -> Self {
        Self {
            abort_signal: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks the signal as aborted.
    ///
    /// Takes `&self` so the signal can be triggered through any shared clone.
    pub fn abort(&self) {
        // `Relaxed` is used because the flag is the only state shared through
        // this type; no other memory is published alongside it.
        self.abort_signal
            .store(true, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns `true` once [`AbortSignal::abort`] has been called on this
    /// signal or on any of its clones.
    pub fn aborted(&self) -> bool {
        self.abort_signal.load(std::sync::atomic::Ordering::Relaxed)
    }
}

impl Default for AbortSignal {
    /// Equivalent to [`AbortSignal::new`]: a signal that has not been aborted.
    fn default() -> Self {
        Self::new()
    }
}

/// A generation request to Ollama.
#[derive(Debug, Clone, Serialize)]
pub struct GenerationRequest {
Expand All @@ -24,6 +48,8 @@ pub struct GenerationRequest {
pub format: Option<FormatType>,
pub keep_alive: Option<KeepAlive>,
pub(crate) stream: bool,
#[serde(skip)]
pub abort_signal: Option<AbortSignal>,
}

impl GenerationRequest {
Expand All @@ -41,6 +67,7 @@ impl GenerationRequest {
keep_alive: None,
// Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods
stream: false,
abort_signal: None,
}
}

Expand Down Expand Up @@ -103,4 +130,10 @@ impl GenerationRequest {
self.keep_alive = Some(keep_alive);
self
}

/// Attaches an abort signal to this request, replacing any signal set
/// earlier, and returns the updated request so calls can be chained.
pub fn abort_signal(self, abort_signal: AbortSignal) -> Self {
    let mut request = self;
    request.abort_signal = Some(abort_signal);
    request
}
}

0 comments on commit 16119f7

Please sign in to comment.