Ollama Function Calling #51

Merged (19 commits) on Jun 4, 2024.
593 changes: 584 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion Cargo.toml
@@ -10,18 +10,24 @@ readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
-reqwest = { version = "0.12.4", default-features = false }
+reqwest = { version = "0.12.4", default-features = false, features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["full"], optional = true }
tokio-stream = { version = "0.1.15", optional = true }
async-trait = { version = "0.1.73" }
url = "2"
log = "0.4"
scraper = { version = "0.19.0", optional = true }
text-splitter = { version = "0.13.1", optional = true }
regex = { version = "1.9.3", optional = true }

[features]
default = ["reqwest/default-tls"]
stream = ["tokio-stream", "reqwest/stream", "tokio"]
rustls = ["reqwest/rustls-tls"]
chat-history = []
function-calling = ["scraper", "text-splitter", "regex", "chat-history"]

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
18 changes: 18 additions & 0 deletions README.md
@@ -147,3 +147,21 @@ let res = ollama.generate_embeddings("llama2:latest".to_string(), prompt, None).
```

_Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vector of floats)._

### Make a function call

```rust
let tools = vec![Arc::new(Scraper::new())];
let parser = Arc::new(NousFunctionCall::new());
let message = ChatMessage::user("What is the current oil price?".to_string());
let res = ollama.send_function_call(
FunctionCallRequest::new(
"adrienbrault/nous-hermes2pro:Q8_0".to_string(),
tools,
vec![message],
),
parser,
).await.unwrap();
```

_Uses the given tools (such as a web scraper) to find an answer and returns a `ChatMessageResponse` containing the answer to the question._
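
Tools are types implementing the `Tool` trait from `generation::functions::tools`. The trait definition isn't shown in this excerpt, so the sketch below is an assumption modeled on the built-in tools (`Scraper`, `DDGSearcher`, `StockScraper`): a name, a description, a JSON parameter schema, and an async `run` method.

```rust
use async_trait::async_trait;
use serde_json::{json, Value};
use std::error::Error;

use ollama_rs::generation::functions::tools::Tool;

struct OilPriceTool;

#[async_trait]
impl Tool for OilPriceTool {
    // Hypothetical method signatures -- check the actual `Tool` trait in
    // this PR before relying on them.
    fn name(&self) -> String {
        "oil_price".to_string()
    }

    fn description(&self) -> String {
        "Returns the current oil price in USD.".to_string()
    }

    fn parameters(&self) -> Value {
        json!({ "type": "object", "properties": {}, "required": [] })
    }

    async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> {
        // A real tool would fetch this from an API; hard-coded for the sketch.
        Ok("82.4".to_string())
    }
}
```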
16 changes: 16 additions & 0 deletions src/error.rs
@@ -34,3 +34,19 @@ impl From<String> for OllamaError {
Self { message }
}
}

impl From<Box<dyn Error>> for OllamaError {
fn from(error: Box<dyn Error>) -> Self {
Self {
message: error.to_string(),
}
}
}

impl From<serde_json::Error> for OllamaError {
fn from(error: serde_json::Error) -> Self {
Self {
message: error.to_string(),
}
}
}
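
These `From` impls let downstream code fold boxed errors and serde failures into `OllamaError` with `?`. A minimal sketch (the helper function is hypothetical):

```rust
use ollama_rs::error::OllamaError;
use serde_json::Value;

// Hypothetical helper: parse a model's raw tool-call output. The `?` operator
// uses the `From<serde_json::Error>` impl above to convert parse failures.
fn parse_tool_call_json(raw: &str) -> Result<Value, OllamaError> {
    let value: Value = serde_json::from_str(raw)?;
    Ok(value)
}
```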
2 changes: 2 additions & 0 deletions src/generation.rs
@@ -1,6 +1,8 @@
pub mod chat;
pub mod completion;
pub mod embeddings;
#[cfg(feature = "function-calling")]
pub mod functions;
pub mod images;
pub mod options;
pub mod parameters;
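
Note: the `functions` module only compiles with the `function-calling` feature, defined in the Cargo.toml hunk above. A downstream crate would opt in along these lines (the version number is illustrative):

```toml
[dependencies]
# Pulls in scraper, text-splitter, regex, and chat-history transitively
# via the function-calling feature. Version is illustrative.
ollama-rs = { version = "0.1", features = ["function-calling"] }
```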
5 changes: 1 addition & 4 deletions src/generation/chat/mod.rs
@@ -1,12 +1,9 @@
use serde::{Deserialize, Serialize};

use crate::Ollama;

pub mod request;

-use request::ChatMessageRequest;

use super::images::Image;
+use request::ChatMessageRequest;

#[cfg(feature = "chat-history")]
use crate::history::MessagesHistory;
106 changes: 106 additions & 0 deletions src/generation/functions/mod.rs
@@ -0,0 +1,106 @@
pub mod pipelines;
pub mod request;
pub mod tools;

pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
pub use crate::generation::functions::request::FunctionCallRequest;
pub use tools::DDGSearcher;
pub use tools::Scraper;
pub use tools::StockScraper;

use crate::error::OllamaError;
use crate::generation::chat::request::ChatMessageRequest;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::functions::pipelines::RequestParserBase;
use crate::generation::functions::tools::Tool;
use std::sync::Arc;

#[cfg(feature = "function-calling")]
impl crate::Ollama {
fn has_system_prompt(&self, messages: &[ChatMessage], system_prompt: &str) -> bool {
// An empty message list has no system prompt; avoids an unwrap panic.
messages.first().map_or(false, |m| m.content == system_prompt)
}

fn has_system_prompt_history(&mut self) -> bool {
self.get_messages_history("default".to_string()).is_some()
}

#[cfg(feature = "chat-history")]
pub async fn send_function_call_with_history(
&mut self,
mut request: FunctionCallRequest,
parser: Arc<dyn RequestParserBase>,
id: String,
) -> Result<ChatMessageResponse, OllamaError> {
if !self.has_system_prompt_history() {
let system_prompt = parser.get_system_message(&request.tools).await;
self.set_system_response(id.clone(), system_prompt.content);

// Format the user's query using the parser's query template
let formatted_query = ChatMessage::user(
parser.format_query(&request.chat.messages.first().unwrap().content),
);
// Replace the original first message with the formatted query
request.chat.messages.remove(0);
request.chat.messages.insert(0, formatted_query);
}

let tool_call_result = self
.send_chat_messages_with_history(
ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages),
id.clone(),
)
.await?;

let tool_call_content: String = tool_call_result.message.clone().unwrap().content;
let result = parser
.parse(
&tool_call_content,
request.chat.model_name.clone(),
request.tools,
)
.await;

match result {
Ok(r) => {
self.add_assistant_response(id.clone(), r.message.clone().unwrap().content);
Ok(r)
}
Err(e) => {
self.add_assistant_response(id.clone(), e.message.clone().unwrap().content);
Ok(e)
}
}
}

pub async fn send_function_call(
&self,
mut request: FunctionCallRequest,
parser: Arc<dyn RequestParserBase>,
) -> Result<ChatMessageResponse, OllamaError> {
request.chat.stream = false;
let system_prompt = parser.get_system_message(&request.tools).await;
let model_name = request.chat.model_name.clone();

// Make sure the first message in the chat is the system prompt
if !self.has_system_prompt(&request.chat.messages, &system_prompt.content) {
request.chat.messages.insert(0, system_prompt);
}
let result = self.send_chat_messages(request.chat).await?;
let response_content: String = result.message.clone().unwrap().content;

let result = parser
.parse(&response_content, model_name, request.tools)
.await;
match result {
Ok(r) => Ok(r),
Err(e) => Ok(e),
}
}
}
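
A usage sketch for the history-aware variant. `Ollama::new_default_with_history` and `DDGSearcher::new` are assumptions (neither constructor appears in this diff), so verify them against the crate:

```rust
use std::sync::Arc;

use ollama_rs::generation::chat::ChatMessage;
use ollama_rs::generation::functions::{DDGSearcher, FunctionCallRequest, NousFunctionCall};
use ollama_rs::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed constructor: an Ollama client with a bounded message history.
    let mut ollama = Ollama::new_default_with_history(30);

    let res = ollama
        .send_function_call_with_history(
            FunctionCallRequest::new(
                "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
                vec![Arc::new(DDGSearcher::new())], // assumed constructor
                vec![ChatMessage::user("What is the current oil price?".to_string())],
            ),
            Arc::new(NousFunctionCall::new()),
            "default".to_string(), // history id, matching has_system_prompt_history above
        )
        .await?;

    println!("{}", res.message.unwrap().content);
    Ok(())
}
```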
26 changes: 26 additions & 0 deletions src/generation/functions/pipelines/mod.rs
@@ -0,0 +1,26 @@
use crate::error::OllamaError;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use std::sync::Arc;

pub mod nous_hermes;
pub mod openai;

#[async_trait]
pub trait RequestParserBase {
async fn parse(
&self,
input: &str,
model_name: String,
tools: Vec<Arc<dyn Tool>>,
) -> Result<ChatMessageResponse, ChatMessageResponse>;
fn format_query(&self, input: &str) -> String {
input.to_string()
}
fn format_response(&self, response: &str) -> String {
response.to_string()
}
async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage;
fn error_handler(&self, error: OllamaError) -> ChatMessageResponse;
}
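
For downstream implementers, a minimal skeleton of this trait with the required methods stubbed; `ChatMessage::system` is an assumed constructor, mirroring the `ChatMessage::user` used elsewhere in this PR:

```rust
use std::sync::Arc;

use async_trait::async_trait;
use ollama_rs::error::OllamaError;
use ollama_rs::generation::chat::{ChatMessage, ChatMessageResponse};
use ollama_rs::generation::functions::pipelines::RequestParserBase;
use ollama_rs::generation::functions::tools::Tool;

struct MyParser;

#[async_trait]
impl RequestParserBase for MyParser {
    async fn parse(
        &self,
        _input: &str,
        _model_name: String,
        _tools: Vec<Arc<dyn Tool>>,
    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
        // Extract the tool call from the model output, run the matching tool,
        // and wrap the result; stubbed in this sketch.
        todo!()
    }

    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
        // Assumed constructor: ChatMessage::system.
        ChatMessage::system(format!("You can call {} tools.", tools.len()))
    }

    fn error_handler(&self, _error: OllamaError) -> ChatMessageResponse {
        // Convert the error into a response the caller can surface; stubbed here.
        todo!()
    }
}
```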
4 changes: 4 additions & 0 deletions src/generation/functions/pipelines/nous_hermes/mod.rs
@@ -0,0 +1,4 @@
pub mod prompts;
pub mod request;

pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
67 changes: 67 additions & 0 deletions src/generation/functions/pipelines/nous_hermes/prompts.rs
@@ -0,0 +1,67 @@
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
Role: |
You are a function calling AI agent with self-recursion.
You can call only one function at a time and analyse the data you get from the function response.
You are provided with function signatures within <tools></tools> XML tags.
The current date is: {date}.
Objective: |
You may use agentic frameworks for reasoning and planning to help with the user's query.
Please call a function and wait for function results to be provided to you in the next iteration.
Don't make assumptions about what values to plug into function arguments.
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present, since the function hasn't been executed yet.
Analyze the data once you get the results and call another function.
At each iteration please continue adding your analysis to the previous summary.
Your final response should directly answer the user query with an analysis or summary of the results of the function calls.
Tools: |
Here are the available tools:
<tools> {tools} </tools>
If the provided function signatures don't include the function you need to call, you may write executable Rust code in markdown syntax and call the code_interpreter() function as follows:
<tool_call>
{"arguments": {"code_markdown": <rust-code>, "name": "code_interpreter"}}
</tool_call>
Make sure that the JSON object above with the code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
Examples: |
Here are some examples of function usage:
[
{
"example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
},
{
"example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
}
]
Schema: |
Use the following pydantic model json schema for each tool call you will make:
{
"name": "tool name",
"description": "tool description",
"parameters": {
"type": "object",
"properties": {
"parameter1": {
"type": "string",
"description": "parameter description"
},
"parameter2": {
"type": "string",
"description": "parameter description"
}
},
"required": [
"parameter1",
"parameter2"
]
}
}
Instructions: |
At the very first turn you don't have <tool_results>, so you shouldn't make up results.
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
Do not stop calling functions until the task has been accomplished or you've reached the maximum of 10 iterations.
Calling multiple functions at once can overload the system and increase cost, so please call one function at a time.
If you plan to continue with analysis, always call another function.
For each function call return a valid JSON object (using double quotes) with the function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"arguments": <args-dict>, "name": <function-name>}
</tool_call>
"#;