diff --git a/src/api_interface/gem_api.rs b/src/api_interface/gem_api.rs
index 10f726f..981979d 100644
--- a/src/api_interface/gem_api.rs
+++ b/src/api_interface/gem_api.rs
@@ -12,15 +12,21 @@ pub struct GeminiExecutor {
     api_key: String,
     client: Client,
     max_tokens: i32,
+    temperature: f64, // Add temperature parameter
+    top_k: i32, // Add top_k parameter
+    logits: bool, // Add logits parameter
 }
 
 impl GeminiExecutor {
-    pub fn new(model: String, api_key: String, max_tokens: i32) -> Self {
+    pub fn new(model: String, api_key: String, max_tokens: i32, temperature: f64, top_k: i32, logits: bool) -> Self {
         Self {
             model,
             api_key,
             client: Client::new(),
             max_tokens,
+            temperature, // Initialize temperature
+            top_k, // Initialize top_k
+            logits, // Initialize logits
         }
     }
 
@@ -36,10 +42,10 @@ impl GeminiExecutor {
         );
 
         let mut generation_config = json!({
-            "temperature": 1.0,
+            "temperature": self.temperature, // Use parameter
             "maxOutputTokens": self.max_tokens,
-            "topP": 0.8,
-            "topK": 10
+            "topP": if self.logits { 0.9 } else { 0.8 }, // Conditional based on logits
+            "topK": self.top_k // Use parameter
         });
 
         let contents: Vec<Value> = input
diff --git a/src/program/atomics.rs b/src/program/atomics.rs
index 5933abe..b7ccf8b 100644
--- a/src/program/atomics.rs
+++ b/src/program/atomics.rs
@@ -64,6 +64,9 @@ pub struct Config {
     pub custom_tools: Option<Vec<CustomToolTemplate>>,
     /// Maximum number of tokens for LLMs to generate per run.
     pub max_tokens: Option<i32>,
+    pub temperature: Option<f64>, // Add temperature field
+    pub top_k: Option<i32>, // Add top_k field
+    pub logits: Option<bool>, // Add logits field
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/src/program/executor.rs b/src/program/executor.rs
index a4d9473..c18f5f7 100644
--- a/src/program/executor.rs
+++ b/src/program/executor.rs
@@ -574,7 +577,10 @@ impl Executor {
             ModelProvider::Gemini => {
                 let api_key = std::env::var("GEMINI_API_KEY").expect("$GEMINI_API_KEY is not set");
                 let max_tokens = config.max_tokens.unwrap_or(800);
-                let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens);
+                let temperature = config.temperature.unwrap_or(1.0); // Default value for temperature
+                let top_k = config.top_k.unwrap_or(10); // Default value for top_k
+                let logits = config.logits.unwrap_or(false); // Default value for logits
+                let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens, temperature, top_k, logits);
                 executor.generate_text(input, schema).await?
             }
             ModelProvider::OpenRouter => {
@@ -649,10 +652,13 @@ impl Executor {
             ModelProvider::Gemini => {
                 let api_key = std::env::var("GEMINI_API_KEY").expect("$GEMINI_API_KEY is not set");
                 let max_tokens = config.max_tokens.unwrap_or(800);
+                let temperature = config.temperature.unwrap_or(1.0); // Default value for temperature
+                let top_k = config.top_k.unwrap_or(10); // Default value for top_k
+                let logits = config.logits.unwrap_or(false); // Default value for logits
                 match self.model{
                     Model::Gemini15Flash | Model::Gemini15Pro => {
-                        let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens);
-                        executor
+                        let executor = GeminiExecutor::new(self.model.to_string(), api_key, max_tokens, temperature, top_k, logits);
+                        executor
                             .function_call(prompt, tools, raw_mode, oai_parser)
                             .await?
                     }
@@ -711,7 +717,6 @@ impl Executor {
         entries[index].clone()
     }
 }
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -728,4 +733,25 @@ mod tests {
         executor.pull_model().await.expect("should pull model");
     }
-}
+
+    #[tokio::test]
+    async fn test_generate_text_with_config() {
+        let config = Config {
+            max_steps: 10,
+            max_time: 60,
+            tools: vec![],
+            custom_tools: None,
+            max_tokens: Some(800),
+            temperature: Some(0.7),
+            top_k: Some(50),
+            logits: Some(true),
+        };
+
+        let executor = Executor::new(Model::Gemini15Flash);
+        let input = vec![MessageInput::new_user_message("Hello, world!")];
+        let schema = None;
+
+        let result = executor.generate_text(input, &schema, &config).await;
+        assert!(result.is_ok());
+    }
+}
\ No newline at end of file
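Usage sketch (illustrative, not part of the patch): with the new `Config` fields in place, a caller can tune Gemini sampling per run and fall back to the same defaults that `executor.rs` applies (`temperature` 1.0, `top_k` 10, `logits` false) when a field is `None`. The `Config` literal below mirrors the struct in `src/program/atomics.rs`; the model identifier string and the surrounding wiring are assumptions for illustration only.

```rust
// Hypothetical caller-side wiring; defaults mirror those in executor.rs.
let config = Config {
    max_steps: 10,
    max_time: 60,
    tools: vec![],
    custom_tools: None,
    max_tokens: Some(800),
    temperature: Some(0.7), // below the 1.0 default => more deterministic sampling
    top_k: Some(40),        // sample from the 40 most likely tokens (default 10)
    logits: Some(true),     // gem_api.rs raises topP from 0.8 to 0.9 when true
};

let executor = GeminiExecutor::new(
    "gemini-1.5-flash".to_string(), // assumed model identifier
    std::env::var("GEMINI_API_KEY").expect("$GEMINI_API_KEY is not set"),
    config.max_tokens.unwrap_or(800),
    config.temperature.unwrap_or(1.0),
    config.top_k.unwrap_or(10),
    config.logits.unwrap_or(false),
);
```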