Add Tool message role & payloads #46

Merged · 3 commits · Dec 22, 2024
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "openai"
version = "1.0.0-alpha.17"
version = "1.0.0-alpha.18"
authors = ["Lorenzo Fontoura <[email protected]>", "valentinegb"]
edition = "2021"
description = "An unofficial Rust library for the OpenAI API."
6 changes: 2 additions & 4 deletions examples/chat_cli.rs
@@ -14,8 +14,7 @@ async fn main() {
let mut messages = vec![ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You are a large language model built into a command line interface as an example of what the `openai` Rust library made by Valentine Briese can do.".to_string()),
name: None,
function_call: None,
..Default::default()
}];

loop {
@@ -28,8 +27,7 @@
messages.push(ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some(user_message_content),
name: None,
function_call: None,
..Default::default()
});

let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", messages.clone())
6 changes: 2 additions & 4 deletions examples/chat_simple.rs
@@ -14,14 +14,12 @@ async fn main() {
ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You are a helpful assistant.".to_string()),
name: None,
function_call: None,
..Default::default()
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some("Tell me a random crab fact".to_string()),
name: None,
function_call: None,
..Default::default()
},
];
let chat_completion = ChatCompletion::builder("gpt-4o", messages.clone())
6 changes: 2 additions & 4 deletions examples/chat_stream_cli.rs
@@ -16,8 +16,7 @@ async fn main() {
let mut messages = vec![ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You're an AI that replies to each message verbosely.".to_string()),
name: None,
function_call: None,
..Default::default()
}];

loop {
@@ -30,8 +29,7 @@
messages.push(ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some(user_message_content),
name: None,
function_call: None,
..Default::default()
});

let chat_stream = ChatCompletionDelta::builder("gpt-3.5-turbo", messages.clone())
133 changes: 132 additions & 1 deletion src/chat.rs
@@ -41,7 +41,7 @@ pub struct ChatCompletionChoiceDelta {
pub delta: ChatCompletionMessageDelta,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct ChatCompletionMessage {
/// The role of the author of this message.
pub role: ChatCompletionMessageRole,
@@ -58,6 +58,18 @@ pub struct ChatCompletionMessage {
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCall>,
/// Tool call that this message is responding to.
/// Required if the role is `Tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool calls that the assistant is requesting to invoke.
/// Can only be populated if the role is `Assistant`,
/// otherwise it should be empty.
#[serde(
skip_serializing_if = "<[_]>::is_empty",
default = "default_tool_calls_deserialization"
)]
pub tool_calls: Vec<ToolCall>,
}

/// Same as ChatCompletionMessage, but received during a response stream.
@@ -75,6 +87,40 @@ pub struct ChatCompletionMessageDelta {
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCallDelta>,
/// Tool call that this message is responding to.
/// Required if the role is `Tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool calls that the assistant is requesting to invoke.
/// Can only be populated if the role is `Assistant`,
/// otherwise it should be empty.
#[serde(
skip_serializing_if = "<[_]>::is_empty",
default = "default_tool_calls_deserialization"
)]
pub tool_calls: Vec<ToolCall>,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
/// The ID of the tool call.
pub id: String,
/// The type of the tool. Currently, only `function` is supported.
pub r#type: String,
/// The function that the model called.
pub function: ToolCallFunction,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallFunction {
/// The name of the function to call.
pub name: String,
/// The arguments to call the function with, as generated by the model in
/// JSON format.
/// Note that the model does not always generate valid JSON, and may
/// hallucinate parameters not defined by your function schema.
/// Validate the arguments in your code before calling your function.
pub arguments: String,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
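
Taken together with the `Default` derive added above, the new types let call sites spell out only the tool-related fields they actually use and leave the rest to `..Default::default()`. A minimal construction sketch, assuming the `openai::chat` paths used by the crate's examples; the call ID, function name, and values here are made up for illustration:

```rust
use openai::chat::{
    ChatCompletionMessage, ChatCompletionMessageRole, ToolCall, ToolCallFunction,
};

fn tool_round_trip() -> Vec<ChatCompletionMessage> {
    vec![
        // Assistant turn requesting a tool invocation.
        ChatCompletionMessage {
            role: ChatCompletionMessageRole::Assistant,
            tool_calls: vec![ToolCall {
                id: "call_0".to_string(), // hypothetical ID; real ones come back from the API
                r#type: "function".to_string(),
                function: ToolCallFunction {
                    name: "mul".to_string(),
                    arguments: r#"{"a": 0.9102847, "b": 28456}"#.to_string(),
                },
            }],
            ..Default::default()
        },
        // Tool turn answering it; `tool_call_id` must match the request above.
        ChatCompletionMessage {
            role: ChatCompletionMessageRole::Tool,
            content: Some("25903.061423199997".to_string()),
            tool_call_id: Some("call_0".to_string()),
            ..Default::default()
        },
    ]
}
```

This mirrors the hand-built messages in the new test further down, just without spelling out the fields that stay at their defaults.
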
@@ -117,6 +163,7 @@ pub enum ChatCompletionMessageRole {
User,
Assistant,
Function,
Tool,
}

#[derive(Serialize, Builder, Debug, Clone)]
@@ -391,6 +438,8 @@ impl From<ChatCompletionDelta> for ChatCompletion {
content: choice.delta.content.clone(),
name: choice.delta.name.clone(),
function_call: choice.delta.function_call.clone().map(|f| f.into()),
tool_call_id: None,
tool_calls: Vec::new(),
},
})
.collect(),
@@ -469,6 +518,16 @@ fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
}
}

fn default_tool_calls_deserialization() -> Vec<ToolCall> {
Vec::new()
}

impl Default for ChatCompletionMessageRole {
fn default() -> Self {
Self::User
}
}

#[cfg(test)]
mod tests {
use super::*;
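
A side effect worth noting: because of the `skip_serializing_if` attributes and the deserialization default above, the new fields stay off the wire unless they are actually populated, so existing request payloads should serialize as before. A small sketch of that behaviour; it assumes `serde_json` is available to the caller, and the key-presence checks do not depend on how the role enum is rendered:

```rust
use openai::chat::{ChatCompletionMessage, ChatCompletionMessageRole, ToolCall, ToolCallFunction};

fn main() -> serde_json::Result<()> {
    // Plain user message: `tool_call_id` is None and `tool_calls` is empty,
    // so neither key should appear in the JSON at all.
    let plain = ChatCompletionMessage {
        role: ChatCompletionMessageRole::User,
        content: Some("Hello!".to_string()),
        ..Default::default()
    };
    let json = serde_json::to_string(&plain)?;
    assert!(!json.contains("tool_call_id"));
    assert!(!json.contains("tool_calls"));

    // Once a tool call is attached, the `tool_calls` key is serialized.
    let with_call = ChatCompletionMessage {
        role: ChatCompletionMessageRole::Assistant,
        tool_calls: vec![ToolCall {
            id: "call_0".to_string(), // hypothetical ID for illustration
            r#type: "function".to_string(),
            function: ToolCallFunction {
                name: "noop".to_string(),
                arguments: "{}".to_string(),
            },
        }],
        ..Default::default()
    };
    assert!(serde_json::to_string(&with_call)?.contains("tool_calls"));
    Ok(())
}
```
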
@@ -486,6 +545,8 @@ mod tests {
content: Some("Hello!".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
@@ -525,6 +586,8 @@
),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
// Determinism currently comes from temperature 0, not seed.
@@ -560,6 +623,8 @@
content: Some("Hello!".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
@@ -596,6 +661,8 @@
content: Some("What is the weather in Boston?".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}
]
).functions([ChatCompletionFunctionDefinition {
@@ -663,6 +730,8 @@
content: Some("Write an example JSON for a JWT header using RS256".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
@@ -724,4 +793,66 @@
}
merged.unwrap().into()
}

#[tokio::test]
async fn chat_tool_response_completion() {
dotenv().ok();
let credentials = Credentials::from_env();

let chat_completion = ChatCompletion::builder(
"gpt-4o-mini",
[
ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some("What's 0.9102847*28456? reply in plain text please".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::Assistant,
content: Some("Let me calculate that for you.".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: vec![ToolCall {
id: "the_tool_call".to_string(),
r#type: "function".to_string(),
function: ToolCallFunction {
name: "mul".to_string(),
arguments: "not_required_to_be_valid_here".to_string(),
},
}],
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::Tool,
content: Some("the result is 25903.061423199997".to_string()),
name: None,
function_call: None,
tool_call_id: Some("the_tool_call".to_owned()),
tool_calls: Vec::new(),
},
],
)
// Determinism currently comes from temperature 0, not seed.
.temperature(0.0)
.seed(1337u64)
.credentials(credentials)
.create()
.await
.unwrap();

assert_eq!(
chat_completion
.choices
.first()
.unwrap()
.message
.content
.as_ref()
.unwrap(),
"The result of 0.9102847 multiplied by 28456 is approximately 25903.06."
);
}
}
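
For the direction the test above does not exercise, the presumable consumer-side flow is: read any `tool_calls` the model returns, run each function, and push one `Tool` message per call before requesting the next completion. A hedged sketch of that loop body, where `run_tool` is a hypothetical dispatcher and error handling is elided:

```rust
use openai::chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole};

// Hypothetical dispatcher: runs the named tool on its JSON arguments and returns the result text.
fn run_tool(name: &str, arguments: &str) -> String {
    format!("ran {name} with {arguments}")
}

fn push_tool_results(messages: &mut Vec<ChatCompletionMessage>, completion: &ChatCompletion) {
    let assistant = completion.choices.first().unwrap().message.clone();
    let tool_calls = assistant.tool_calls.clone();

    // Keep the assistant turn (including its tool calls) in the history.
    messages.push(assistant);

    // Answer each requested call with a Tool-role message keyed by `tool_call_id`.
    for call in tool_calls {
        messages.push(ChatCompletionMessage {
            role: ChatCompletionMessageRole::Tool,
            content: Some(run_tool(&call.function.name, &call.function.arguments)),
            tool_call_id: Some(call.id),
            ..Default::default()
        });
    }
}
```

After that, `messages` can be fed back through `ChatCompletion::builder` as in the test above to get the model's final answer.
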