Skip to content

Commit

Permalink
Merge pull request #46 from rellfy/message-roles
Browse files Browse the repository at this point in the history
Add Tool message role & payloads
  • Loading branch information
rellfy authored Dec 22, 2024
2 parents 39dd271 + 02e4eae commit ccaebcc
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "openai"
version = "1.0.0-alpha.17"
version = "1.0.0-alpha.18"
authors = ["Lorenzo Fontoura <[email protected]>", "valentinegb"]
edition = "2021"
description = "An unofficial Rust library for the OpenAI API."
Expand Down
6 changes: 2 additions & 4 deletions examples/chat_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ async fn main() {
let mut messages = vec![ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You are a large language model built into a command line interface as an example of what the `openai` Rust library made by Valentine Briese can do.".to_string()),
name: None,
function_call: None,
..Default::default()
}];

loop {
Expand All @@ -28,8 +27,7 @@ async fn main() {
messages.push(ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some(user_message_content),
name: None,
function_call: None,
..Default::default()
});

let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", messages.clone())
Expand Down
6 changes: 2 additions & 4 deletions examples/chat_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ async fn main() {
ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You are a helpful assistant.".to_string()),
name: None,
function_call: None,
..Default::default()
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some("Tell me a random crab fact".to_string()),
name: None,
function_call: None,
..Default::default()
},
];
let chat_completion = ChatCompletion::builder("gpt-4o", messages.clone())
Expand Down
6 changes: 2 additions & 4 deletions examples/chat_stream_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ async fn main() {
let mut messages = vec![ChatCompletionMessage {
role: ChatCompletionMessageRole::System,
content: Some("You're an AI that replies to each message verbosely.".to_string()),
name: None,
function_call: None,
..Default::default()
}];

loop {
Expand All @@ -30,8 +29,7 @@ async fn main() {
messages.push(ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some(user_message_content),
name: None,
function_call: None,
..Default::default()
});

let chat_stream = ChatCompletionDelta::builder("gpt-3.5-turbo", messages.clone())
Expand Down
133 changes: 132 additions & 1 deletion src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct ChatCompletionChoiceDelta {
pub delta: ChatCompletionMessageDelta,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct ChatCompletionMessage {
/// The role of the author of this message.
pub role: ChatCompletionMessageRole,
Expand All @@ -58,6 +58,18 @@ pub struct ChatCompletionMessage {
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCall>,
/// Tool call that this message is responding to.
/// Required if the role is `Tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool calls that the assistant is requesting to invoke.
/// Can only be populated if the role is `Assistant`,
/// otherwise it should be empty.
#[serde(
skip_serializing_if = "<[_]>::is_empty",
default = "default_tool_calls_deserialization"
)]
pub tool_calls: Vec<ToolCall>,
}

/// Same as ChatCompletionMessage, but received during a response stream.
Expand All @@ -75,6 +87,40 @@ pub struct ChatCompletionMessageDelta {
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCallDelta>,
/// Tool call that this message is responding to.
/// Required if the role is `Tool`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool calls that the assistant is requesting to invoke.
/// Can only be populated if the role is `Assistant`,
/// otherwise it should be empty.
#[serde(
skip_serializing_if = "<[_]>::is_empty",
default = "default_tool_calls_deserialization"
)]
pub tool_calls: Vec<ToolCall>,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
/// The ID of the tool call.
pub id: String,
/// The type of the tool. Currently, only `function` is supported.
pub r#type: String,
/// The function that the model called.
pub function: ToolCallFunction,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallFunction {
/// The name of the function to call.
pub name: String,
/// The arguments to call the function with, as generated by the model in
/// JSON format.
/// Note that the model does not always generate valid JSON, and may
/// hallucinate parameters not defined by your function schema.
/// Validate the arguments in your code before calling your function.
pub arguments: String,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -117,6 +163,7 @@ pub enum ChatCompletionMessageRole {
User,
Assistant,
Function,
Tool,
}

#[derive(Serialize, Builder, Debug, Clone)]
Expand Down Expand Up @@ -391,6 +438,8 @@ impl From<ChatCompletionDelta> for ChatCompletion {
content: choice.delta.content.clone(),
name: choice.delta.name.clone(),
function_call: choice.delta.function_call.clone().map(|f| f.into()),
tool_call_id: None,
tool_calls: Vec::new(),
},
})
.collect(),
Expand Down Expand Up @@ -469,6 +518,16 @@ fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
}
}

fn default_tool_calls_deserialization() -> Vec<ToolCall> {
Vec::new()
}

impl Default for ChatCompletionMessageRole {
fn default() -> Self {
Self::User
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -486,6 +545,8 @@ mod tests {
content: Some("Hello!".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
Expand Down Expand Up @@ -525,6 +586,8 @@ mod tests {
),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
// Determinism currently comes from temperature 0, not seed.
Expand Down Expand Up @@ -560,6 +623,8 @@ mod tests {
content: Some("Hello!".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
Expand Down Expand Up @@ -596,6 +661,8 @@ mod tests {
content: Some("What is the weather in Boston?".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}
]
).functions([ChatCompletionFunctionDefinition {
Expand Down Expand Up @@ -663,6 +730,8 @@ mod tests {
content: Some("Write an example JSON for a JWT header using RS256".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
}],
)
.temperature(0.0)
Expand Down Expand Up @@ -724,4 +793,66 @@ mod tests {
}
merged.unwrap().into()
}

#[tokio::test]
async fn chat_tool_response_completion() {
dotenv().ok();
let credentials = Credentials::from_env();

let chat_completion = ChatCompletion::builder(
"gpt-4o-mini",
[
ChatCompletionMessage {
role: ChatCompletionMessageRole::User,
content: Some("What's 0.9102847*28456? reply in plain text please".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: Vec::new(),
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::Assistant,
content: Some("Let me calculate that for you.".to_string()),
name: None,
function_call: None,
tool_call_id: None,
tool_calls: vec![ToolCall {
id: "the_tool_call".to_string(),
r#type: "function".to_string(),
function: ToolCallFunction {
name: "mul".to_string(),
arguments: "not_required_to_be_valid_here".to_string(),
},
}],
},
ChatCompletionMessage {
role: ChatCompletionMessageRole::Tool,
content: Some("the result is 25903.061423199997".to_string()),
name: None,
function_call: None,
tool_call_id: Some("the_tool_call".to_owned()),
tool_calls: Vec::new(),
},
],
)
// Determinism currently comes from temperature 0, not seed.
.temperature(0.0)
.seed(1337u64)
.credentials(credentials)
.create()
.await
.unwrap();

assert_eq!(
chat_completion
.choices
.first()
.unwrap()
.message
.content
.as_ref()
.unwrap(),
"The result of 0.9102847 multiplied by 28456 is approximately 25903.06."
);
}
}

0 comments on commit ccaebcc

Please sign in to comment.