From 20fed22b8315840f00a34327fce4d93faeee0a12 Mon Sep 17 00:00:00 2001
From: selimseker <selim.seker00@gmail.com>
Date: Fri, 18 Oct 2024 13:41:29 +0300
Subject: [PATCH] add gemini models

---
 main.go         | 24 +++++++++++++++++++++++-
 utils/cmd.go    | 26 +++++++++++++++++++-------
 utils/gemini.go | 24 ++++++++++++++++++++++++
 3 files changed, 66 insertions(+), 8 deletions(-)
 create mode 100644 utils/gemini.go

diff --git a/main.go b/main.go
index 2c006c0..38acfa8 100644
--- a/main.go
+++ b/main.go
@@ -54,6 +54,18 @@ var (
 		"o1-preview",
 	}
 
+	GEMINI_MODELS = []string{
+		"gemini-1.0-pro",
+
+		"gemini-1.5-pro",
+		"gemini-1.5-pro-exp-0827",
+		"gemini-1.5-flash",
+
+		"gemma-2-2b-it",
+		"gemma-2-9b-it",
+		"gemma-2-27b-it",
+	}
+
 	// Default admin public key, it will be used unless --dkn-admin-public-key is given
 	DKN_ADMIN_PUBLIC_KEY = "0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658"
 )
@@ -123,7 +135,7 @@ func main() {
 
 	// if DKN_MODELS are still empty, pick model interactively
 	if envvars["DKN_MODELS"] == "" || *pick_model {
-		pickedModels := utils.PickModels(OPENAI_MODELS, OLLAMA_MODELS)
+		pickedModels := utils.PickModels(OPENAI_MODELS, GEMINI_MODELS, OLLAMA_MODELS)
 		if pickedModels == "" {
 			fmt.Println("No valid model picked")
 			utils.ExitWithDelay(1)
@@ -141,6 +153,16 @@ func main() {
 		envvars["OPENAI_API_KEY"] = apikey
 	}
 
+	// check gemini api key
+	if utils.IsGeminiRequired(envvars["DKN_MODELS"], &GEMINI_MODELS) && envvars["GEMINI_API_KEY"] == "" {
+		apikey := utils.GetUserInput("Enter your Gemini API Key", true)
+		if apikey == "" {
+			fmt.Println("Invalid input, please place your GEMINI_API_KEY to .env file")
+			utils.ExitWithDelay(1)
+		}
+		envvars["GEMINI_API_KEY"] = apikey
+	}
+
 	// check ollama environment
 	if utils.IsOllamaRequired(envvars["DKN_MODELS"], &OLLAMA_MODELS) {
 		ollamaHost, ollamaPort := utils.HandleOllamaEnv(envvars["OLLAMA_HOST"], envvars["OLLAMA_PORT"])
diff --git a/utils/cmd.go b/utils/cmd.go
index 64a6156..8e85126 100644
--- a/utils/cmd.go
+++ b/utils/cmd.go
@@ -130,15 +130,16 @@ func RunCommand(working_dir string, outputDest string, wait bool, timeout time.D
 	return pid, nil
 }
 
-// PickModels prompts the user to pick models from the available OpenAI and Ollama models.
+// PickModels prompts the user to pick models from the available OpenAI, Google and Ollama models.
 //
 // Parameters:
 //   - openai_models: A slice of available OpenAI model names.
+//   - gemini_models: A slice of available Gemini model names.
 //   - ollama_models: A slice of available Ollama model names.
 //
 // Returns:
 //   - string: A comma-separated string of selected model names.
-func PickModels(openai_models, ollama_models []string) string {
+func PickModels(openai_models, gemini_models, ollama_models []string) string {
 
 	// column widths
 	idWidth := 4
@@ -160,8 +161,13 @@ func PickModels(openai_models, ollama_models []string) string {
 		provider := "OpenAI"
 		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
 	}
-	for id, model := range ollama_models {
+	for id, model := range gemini_models {
 		modelId := len(openai_models) + id + 1
+		provider := "Google"
+		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
+	}
+	for id, model := range ollama_models {
+		modelId := len(openai_models) + len(gemini_models) + id + 1
 		provider := "Ollama"
 		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
 	}
@@ -195,12 +201,19 @@ func PickModels(openai_models, ollama_models []string) string {
 				picked_models_map[id] = true
 				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, openai_models[id-1])
 			}
-		} else if id > len(openai_models) && id <= len(ollama_models)+len(openai_models) {
+		} else if id > len(openai_models) && id <= len(gemini_models)+len(openai_models) {
+			// gemini model picked
+			if !picked_models_map[id] {
+				// if not already picked, add it to bin
+				picked_models_map[id] = true
+				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, gemini_models[id-len(openai_models)-1])
+			}
+		} else if id > len(openai_models)+len(gemini_models) && id <= len(ollama_models)+len(gemini_models)+len(openai_models) {
 			// ollama model picked
 			if !picked_models_map[id] {
 				// if not already picked, add it to bin
 				picked_models_map[id] = true
-				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(openai_models)-1])
+				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(gemini_models)-len(openai_models)-1])
 			}
 		} else {
 			// out of index, invalid
@@ -209,9 +222,8 @@ func PickModels(openai_models, ollama_models []string) string {
 		}
 	}
 	if len(invalid_selections) != 0 {
-		fmt.Printf("Skipping the invalid selections: %s \n", FormatMapKeys(invalid_selections))
+		fmt.Printf("Skipping the invalid selections: %s \n\n", FormatMapKeys(invalid_selections))
 	}
-	fmt.Printf("\n")
 	return picked_models_str
 }
 
diff --git a/utils/gemini.go b/utils/gemini.go
new file mode 100644
index 0000000..29a4534
--- /dev/null
+++ b/utils/gemini.go
@@ -0,0 +1,24 @@
+package utils
+
+import "strings"
+
+// IsGeminiRequired checks if any of the picked models require Google Gemini by comparing them against a list of available Gemini models.
+//
+// Parameters:
+//   - picked_models: A comma-separated string of model names selected by the user.
+//   - gemini_models: A pointer to a slice of strings containing available Gemini model names.
+//
+// Returns:
+//   - bool: Returns true if any of the picked models match Gemini models, indicating that Gemini is required, otherwise false.
+func IsGeminiRequired(picked_models string, gemini_models *[]string) bool {
+	required := false
+	for _, model := range strings.Split(picked_models, ",") {
+		for _, gemini_model := range *gemini_models {
+			if model == gemini_model {
+				required = true
+				break
+			}
+		}
+	}
+	return required
+}