Started implementation of temperature, top_k and logits for local models #31

Draft · wants to merge 1 commit into main
14 changes: 10 additions & 4 deletions src/api_interface/gem_api.rs
@@ -12,15 +12,21 @@ pub struct GeminiExecutor {
     api_key: String,
     client: Client,
     max_tokens: i32,
+    temperature: f64, // sampling temperature
+    top_k: i32,       // top-k sampling cutoff
+    logits: bool,     // whether logits were requested
 }
 
 impl GeminiExecutor {
-    pub fn new(model: String, api_key: String, max_tokens: i32) -> Self {
+    pub fn new(model: String, api_key: String, max_tokens: i32, temperature: f64, top_k: i32, logits: bool) -> Self {
         Self {
             model,
             api_key,
             client: Client::new(),
             max_tokens,
+            temperature,
+            top_k,
+            logits,
         }
     }
@@ -36,10 +42,10 @@ impl GeminiExecutor {
         );
 
         let mut generation_config = json!({
-            "temperature": 1.0,
+            "temperature": self.temperature,
             "maxOutputTokens": self.max_tokens,
-            "topP": 0.8,
-            "topK": 10
+            "topP": if self.logits { 0.9 } else { 0.8 }, // widen nucleus sampling when logits are requested
+            "topK": self.top_k
         });
 
         let contents: Vec<Value> = input
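As a quick illustration of what the executor now sends, here is a minimal, self-contained sketch of the request-body construction. build_generation_config is a hypothetical helper for illustration only, not part of this PR; it assumes serde_json and mirrors the conditional topP from the hunk above.

use serde_json::{json, Value};

// Hypothetical helper mirroring the generation_config built in generate_text above.
fn build_generation_config(temperature: f64, top_k: i32, logits: bool, max_tokens: i32) -> Value {
    json!({
        "temperature": temperature,
        "maxOutputTokens": max_tokens,
        // The PR widens nucleus sampling from 0.8 to 0.9 when logits are requested.
        "topP": if logits { 0.9 } else { 0.8 },
        "topK": top_k
    })
}

fn main() {
    // With the values used in the new test (0.7, 50, logits on), topP resolves to 0.9.
    println!("{}", build_generation_config(0.7, 50, true, 800));
}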
3 changes: 3 additions & 0 deletions src/program/atomics.rs
@@ -64,6 +64,9 @@ pub struct Config {
     pub custom_tools: Option<Vec<CustomToolTemplate>>,
     /// Maximum number of tokens for LLMs to generate per run.
     pub max_tokens: Option<i32>,
+    pub temperature: Option<f64>, // sampling temperature; None falls back to the default
+    pub top_k: Option<i32>,       // top-k sampling cutoff
+    pub logits: Option<bool>,     // whether to request logits
 }
 
 #[derive(Debug, Deserialize)]
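Because all three new fields are Option-al, existing configuration files that omit them continue to deserialize unchanged. A minimal sketch of that behavior — ConfigSketch is a trimmed stand-in for the real Config, assuming serde and serde_json:

use serde::Deserialize;

// Trimmed stand-in for the real Config; only the fields relevant here.
#[derive(Debug, Deserialize)]
struct ConfigSketch {
    max_tokens: Option<i32>,
    temperature: Option<f64>,
    top_k: Option<i32>,
    logits: Option<bool>,
}

fn main() {
    // An older config without the new keys still parses; the new fields are None.
    let old: ConfigSketch = serde_json::from_str(r#"{ "max_tokens": 800 }"#).unwrap();
    assert!(old.temperature.is_none() && old.top_k.is_none() && old.logits.is_none());

    // A newer config can set all three explicitly.
    let new: ConfigSketch = serde_json::from_str(
        r#"{ "max_tokens": 800, "temperature": 0.7, "top_k": 50, "logits": true }"#,
    )
    .unwrap();
    assert_eq!(new.top_k, Some(50));
}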
36 changes: 31 additions & 5 deletions src/program/executor.rs
@@ -574,7 +574,10 @@ impl Executor {
             ModelProvider::Gemini => {
                 let api_key = std::env::var("GEMINI_API_KEY").expect("$GEMINI_API_KEY is not set");
                 let max_tokens = config.max_tokens.unwrap_or(800);
-                let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens);
+                let temperature = config.temperature.unwrap_or(1.0); // defaults preserve the old hard-coded values
+                let top_k = config.top_k.unwrap_or(10);
+                let logits = config.logits.unwrap_or(false);
+                let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens, temperature, top_k, logits);
                 executor.generate_text(input, schema).await?
             }
             ModelProvider::OpenRouter => {
@@ -649,10 +652,13 @@ impl Executor {
             ModelProvider::Gemini => {
                 let api_key = std::env::var("GEMINI_API_KEY").expect("$GEMINI_API_KEY is not set");
                 let max_tokens = config.max_tokens.unwrap_or(800);
+                let temperature = config.temperature.unwrap_or(1.0); // Default value for temperature
+                let top_k = config.top_k.unwrap_or(10); // Default value for top_k
+                let logits = config.logits.unwrap_or(false); // Default value for logits
                 match self.model {
                     Model::Gemini15Flash | Model::Gemini15Pro => {
-                        let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens);
-                        executor
+                        let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens, temperature, top_k, logits);
+                        executor
                             .function_call(prompt, tools, raw_mode, oai_parser)
                             .await?
                     }
@@ -711,7 +717,6 @@ impl Executor {
         entries[index].clone()
     }
 }
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -728,4 +733,25 @@ mod tests {
 
         executor.pull_model().await.expect("should pull model");
     }
-}
+
+    #[tokio::test]
+    async fn test_generate_text_with_config() {
+        let config = Config {
+            max_steps: 10,
+            max_time: 60,
+            tools: vec![],
+            custom_tools: None,
+            max_tokens: Some(800),
+            temperature: Some(0.7),
+            top_k: Some(50),
+            logits: Some(true),
+        };
+
+        let executor = Executor::new(Model::Gemini15Flash);
+        let input = vec![MessageInput::new_user_message("Hello, world!")];
+        let schema = None;
+
+        let result = executor.generate_text(input, &schema, &config).await;
+        assert!(result.is_ok());
+    }
+}
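A note on the new test: as written it exercises the full path against the live Gemini API, so it requires GEMINI_API_KEY to be set and network access, and a transient API failure will fail the assertion. A cheaper, deterministic complement could assert on the request body alone — a sketch, reusing the hypothetical build_generation_config helper from the earlier sketch (not part of this PR):

#[test]
fn generation_config_respects_parameters() {
    let cfg = build_generation_config(0.7, 50, true, 800);
    assert_eq!(cfg["temperature"], 0.7);
    assert_eq!(cfg["topK"], 50);
    assert_eq!(cfg["topP"], 0.9); // logits == true widens topP to 0.9
    assert_eq!(cfg["maxOutputTokens"], 800);
}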