diff --git a/Cargo.lock b/Cargo.lock index ec51abb..9f61ef8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -486,6 +486,7 @@ dependencies = [ "serde_json", "tokio", "tokio-stream", + "url", ] [[package]] @@ -563,9 +564,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project" @@ -634,9 +635,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" +checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10" dependencies = [ "base64", "bytes", @@ -1089,9 +1090,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", diff --git a/Cargo.toml b/Cargo.toml index db3acc0..880a015 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,12 @@ readme = "README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -reqwest = { version = "0.12.3", default-features = false } -serde = { version = "1.0.190", features = ["derive"] } -serde_json = "1.0.116" +reqwest = { version = "0.12.4", default-features = false } +serde = { version = "1", features = ["derive"] } +serde_json = "1" tokio = { version = "1", features = ["full"], optional = true } tokio-stream = { version = "0.1.15", optional = true } +url = "2" [features] default = ["reqwest/default-tls"] diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 967211d..9b8458e 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -29,13 +29,13 @@ impl Ollama { let mut request = request; request.stream = true; - let uri = format!("{}/api/chat", self.uri()); + let url = format!("{}api/chat", self.url_str()); let serialized = serde_json::to_string(&request) .map_err(|e| e.to_string()) .unwrap(); let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await @@ -74,11 +74,11 @@ impl Ollama { let mut request = request; request.stream = false; - let uri = format!("{}/api/chat", self.uri()); + let url = format!("{}api/chat", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/generation/completion/mod.rs b/src/generation/completion/mod.rs index 60d80e8..b81081a 100644 --- a/src/generation/completion/mod.rs +++ b/src/generation/completion/mod.rs @@ -30,11 +30,11 @@ impl Ollama { let mut request = request; request.stream = true; - let uri = format!("{}/api/generate", self.uri()); + let url = format!("{}api/generate", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await @@ -68,11 +68,11 @@ impl Ollama { let mut request = request; request.stream = false; - let uri = format!("{}/api/generate", self.uri()); + let url = format!("{}api/generate", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/generation/embeddings.rs b/src/generation/embeddings.rs index 2e10855..f06eeb9 100644 --- a/src/generation/embeddings.rs +++ b/src/generation/embeddings.rs @@ -20,11 +20,11 @@ impl Ollama { options, }; - let uri = format!("{}/api/embeddings", self.uri()); + let url = format!("{}api/embeddings", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/history.rs b/src/history.rs index 41fb32b..3525822 100644 --- a/src/history.rs +++ b/src/history.rs @@ -59,15 +59,42 @@ impl Ollama { } /// Create new instance with chat history - pub fn new_with_history(host: String, port: u16, messages_number_limit: u16) -> Self { + /// + /// # Panics + /// + /// Panics if the host is not a valid URL or if the URL cannot have a port. + pub fn new_with_history( + host: impl crate::IntoUrl, + port: u16, + messages_number_limit: u16, + ) -> Self { + let mut url = host.into_url().unwrap(); + url.set_port(Some(port)).unwrap(); + Self::new_with_history_from_url(url, messages_number_limit) + } + + /// Create new instance with chat history from a [`url::Url`]. + #[inline] + pub fn new_with_history_from_url(url: url::Url, messages_number_limit: u16) -> Self { Self { - host, - port, + url, messages_history: Some(MessagesHistory::new(messages_number_limit)), ..Default::default() } } + #[inline] + pub fn try_new_with_history( + url: impl crate::IntoUrl, + messages_number_limit: u16, + ) -> Result { + Ok(Self { + url: url.into_url()?, + messages_history: Some(MessagesHistory::new(messages_number_limit)), + ..Default::default() + }) + } + /// Add AI's message to a history pub fn add_assistant_response(&mut self, entry_id: String, message: String) { if let Some(messages_history) = self.messages_history.as_mut() { diff --git a/src/lib.rs b/src/lib.rs index 9f6aefd..18ec517 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,27 +4,136 @@ pub mod generation; pub mod history; pub mod models; +use url::Url; + +/// A trait to try to convert some type into a [`Url`]. +/// +/// This trait is "sealed", such that only types within ollama-rs can +/// implement it. +pub trait IntoUrl: IntoUrlSealed {} + +impl IntoUrl for Url {} +impl IntoUrl for String {} +impl<'a> IntoUrl for &'a str {} +impl<'a> IntoUrl for &'a String {} + +pub trait IntoUrlSealed { + fn into_url(self) -> Result; + + fn as_str(&self) -> &str; +} + +impl IntoUrlSealed for Url { + fn into_url(self) -> Result { + Ok(self) + } + + fn as_str(&self) -> &str { + self.as_str() + } +} + +impl<'a> IntoUrlSealed for &'a str { + fn into_url(self) -> Result { + Url::parse(self)?.into_url() + } + + fn as_str(&self) -> &str { + self + } +} + +impl<'a> IntoUrlSealed for &'a String { + fn into_url(self) -> Result { + (&**self).into_url() + } + + fn as_str(&self) -> &str { + self.as_ref() + } +} + +impl IntoUrlSealed for String { + fn into_url(self) -> Result { + (&*self).into_url() + } + + fn as_str(&self) -> &str { + self.as_ref() + } +} + #[derive(Debug, Clone)] pub struct Ollama { - pub(crate) host: String, - pub(crate) port: u16, + pub(crate) url: Url, pub(crate) reqwest_client: reqwest::Client, #[cfg(feature = "chat-history")] pub(crate) messages_history: Option, } impl Ollama { - pub fn new(host: String, port: u16) -> Self { + /// # Panics + /// + /// Panics if the host is not a valid URL or if the URL cannot have a port. + pub fn new(host: impl IntoUrl, port: u16) -> Self { + let mut url: Url = host.into_url().unwrap(); + url.set_port(Some(port)).unwrap(); + + Self::from_url(url) + } + + /// Tries to create new instance by converting `url` into [`Url`]. + #[inline] + pub fn try_new(url: impl IntoUrl) -> Result { + Ok(Self::from_url(url.into_url()?)) + } + + /// Create new instance from a [`Url`]. + #[inline] + pub fn from_url(url: Url) -> Self { Self { - host, - port, + url, ..Default::default() } } /// Returns the http URI of the Ollama instance + /// + /// # Panics + /// + /// Panics if the URL does not have a host. + #[inline] pub fn uri(&self) -> String { - format!("{}:{}", self.host, self.port) + self.url.host().unwrap().to_string() + } + + /// Returns the URL of the Ollama instance as a [`Url`]. + pub fn url(&self) -> &Url { + &self.url + } + + /// Returns the URL of the Ollama instance as a [str]. + /// + /// Syntax in pseudo-BNF: + /// + /// ```bnf + /// url = scheme ":" [ hierarchical | non-hierarchical ] [ "?" query ]? [ "#" fragment ]? + /// non-hierarchical = non-hierarchical-path + /// non-hierarchical-path = /* Does not start with "/" */ + /// hierarchical = authority? hierarchical-path + /// authority = "//" userinfo? host [ ":" port ]? + /// userinfo = username [ ":" password ]? "@" + /// hierarchical-path = [ "/" path-segment ]+ + /// ``` + #[inline] + pub fn url_str(&self) -> &str { + self.url.as_str() + } +} + +impl From for Ollama { + fn from(url: Url) -> Self { + Self::from_url(url) } } @@ -32,8 +141,7 @@ impl Default for Ollama { /// Returns a default Ollama instance with the host set to `http://127.0.0.1:11434`. fn default() -> Self { Self { - host: "http://127.0.0.1".to_string(), - port: 11434, + url: Url::parse("http://127.0.0.1:11434").unwrap(), reqwest_client: reqwest::Client::new(), #[cfg(feature = "chat-history")] messages_history: None, diff --git a/src/models/copy.rs b/src/models/copy.rs index b3b5e6d..65b7c16 100644 --- a/src/models/copy.rs +++ b/src/models/copy.rs @@ -14,11 +14,11 @@ impl Ollama { destination, }; - let uri = format!("{}/api/copy", self.uri()); + let url = format!("{}api/copy", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/models/create.rs b/src/models/create.rs index 5e00d6e..eb580a7 100644 --- a/src/models/create.rs +++ b/src/models/create.rs @@ -21,11 +21,11 @@ impl Ollama { request.stream = true; - let uri = format!("{}/api/create", self.uri()); + let url = format!("{}api/create", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await @@ -63,11 +63,11 @@ impl Ollama { &self, request: CreateModelRequest, ) -> crate::error::Result { - let uri = format!("{}/api/create", self.uri()); + let url = format!("{}api/create", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/models/delete.rs b/src/models/delete.rs index 83db488..ad25d76 100644 --- a/src/models/delete.rs +++ b/src/models/delete.rs @@ -7,11 +7,11 @@ impl Ollama { pub async fn delete_model(&self, model_name: String) -> crate::error::Result<()> { let request = DeleteModelRequest { model_name }; - let uri = format!("{}/api/delete", self.uri()); + let url = format!("{}api/delete", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .delete(uri) + .delete(url) .body(serialized) .send() .await diff --git a/src/models/list_local.rs b/src/models/list_local.rs index f2d1bec..f0c86db 100644 --- a/src/models/list_local.rs +++ b/src/models/list_local.rs @@ -6,10 +6,10 @@ use super::LocalModel; impl Ollama { pub async fn list_local_models(&self) -> crate::error::Result> { - let uri = format!("{}/api/tags", self.uri()); + let url = format!("{}api/tags", self.url_str()); let res = self .reqwest_client - .get(uri) + .get(url) .send() .await .map_err(|e| e.to_string())?; diff --git a/src/models/pull.rs b/src/models/pull.rs index 5ea87d0..b764a19 100644 --- a/src/models/pull.rs +++ b/src/models/pull.rs @@ -28,11 +28,11 @@ impl Ollama { stream: true, }; - let uri = format!("{}/api/pull", self.uri()); + let url = format!("{}api/pull", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await @@ -79,11 +79,11 @@ impl Ollama { stream: false, }; - let uri = format!("{}/api/pull", self.uri()); + let url = format!("{}api/pull", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/models/push.rs b/src/models/push.rs index 6dbe091..9ec592e 100644 --- a/src/models/push.rs +++ b/src/models/push.rs @@ -28,11 +28,11 @@ impl Ollama { stream: true, }; - let uri = format!("{}/api/push", self.uri()); + let url = format!("{}api/push", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await @@ -80,11 +80,11 @@ impl Ollama { stream: false, }; - let uri = format!("{}/api/push", self.uri()); + let url = format!("{}api/push", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await diff --git a/src/models/show_info.rs b/src/models/show_info.rs index d7b5baf..e32083f 100644 --- a/src/models/show_info.rs +++ b/src/models/show_info.rs @@ -7,12 +7,12 @@ use super::ModelInfo; impl Ollama { /// Show details about a model including modelfile, template, parameters, license, and system prompt. pub async fn show_model_info(&self, model_name: String) -> crate::error::Result { - let uri = format!("{}/api/show", self.uri()); + let url = format!("{}api/show", self.url_str()); let serialized = serde_json::to_string(&ModelInfoRequest { model_name }).map_err(|e| e.to_string())?; let res = self .reqwest_client - .post(uri) + .post(url) .body(serialized) .send() .await