diff --git a/main.go b/main.go
index 2c006c0..38acfa8 100644
--- a/main.go
+++ b/main.go
@@ -54,6 +54,18 @@ var (
 		"o1-preview",
 	}
 
+	GEMINI_MODELS = []string{
+		"gemini-1.0-pro",
+
+		"gemini-1.5-pro",
+		"gemini-1.5-pro-exp-0827",
+		"gemini-1.5-flash",
+
+		"gemma-2-2b-it",
+		"gemma-2-9b-it",
+		"gemma-2-27b-it",
+	}
+
 	// Default admin public key, it will be used unless --dkn-admin-public-key is given
 	DKN_ADMIN_PUBLIC_KEY = "0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658"
 )
@@ -123,7 +135,7 @@ func main() {
 	// if DKN_MODELS are still empty, pick model interactively
 	if envvars["DKN_MODELS"] == "" || *pick_model {
-		pickedModels := utils.PickModels(OPENAI_MODELS, OLLAMA_MODELS)
+		pickedModels := utils.PickModels(OPENAI_MODELS, GEMINI_MODELS, OLLAMA_MODELS)
 		if pickedModels == "" {
 			fmt.Println("No valid model picked")
 			utils.ExitWithDelay(1)
 		}
@@ -141,6 +153,16 @@ func main() {
 		envvars["OPENAI_API_KEY"] = apikey
 	}
 
+	// check gemini api key
+	if utils.IsGeminiRequired(envvars["DKN_MODELS"], &GEMINI_MODELS) && envvars["GEMINI_API_KEY"] == "" {
+		apikey := utils.GetUserInput("Enter your Gemini API Key", true)
+		if apikey == "" {
+			fmt.Println("Invalid input, please place your GEMINI_API_KEY to .env file")
+			utils.ExitWithDelay(1)
+		}
+		envvars["GEMINI_API_KEY"] = apikey
+	}
+
 	// check ollama environment
 	if utils.IsOllamaRequired(envvars["DKN_MODELS"], &OLLAMA_MODELS) {
 		ollamaHost, ollamaPort := utils.HandleOllamaEnv(envvars["OLLAMA_HOST"], envvars["OLLAMA_PORT"])
diff --git a/utils/cmd.go b/utils/cmd.go
index 64a6156..8e85126 100644
--- a/utils/cmd.go
+++ b/utils/cmd.go
@@ -130,15 +130,16 @@ func RunCommand(working_dir string, outputDest string, wait bool, timeout time.D
 	return pid, nil
 }
 
-// PickModels prompts the user to pick models from the available OpenAI and Ollama models.
+// PickModels prompts the user to pick models from the available OpenAI, Google and Ollama models.
 //
 // Parameters:
 //   - openai_models: A slice of available OpenAI model names.
+//   - gemini_models: A slice of available Gemini model names.
 //   - ollama_models: A slice of available Ollama model names.
 //
 // Returns:
 //   - string: A comma-separated string of selected model names.
-func PickModels(openai_models, ollama_models []string) string {
+func PickModels(openai_models, gemini_models, ollama_models []string) string {
 	// column widths
 	idWidth := 4
 
@@ -160,8 +161,13 @@ func PickModels(openai_models, gemini_models, ollama_models []string) string {
 		provider := "OpenAI"
 		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
 	}
-	for id, model := range ollama_models {
+	for id, model := range gemini_models {
 		modelId := len(openai_models) + id + 1
+		provider := "Google"
+		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
+	}
+	for id, model := range ollama_models {
+		modelId := len(openai_models) + len(gemini_models) + id + 1
 		provider := "Ollama"
 		fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
 	}
@@ -195,12 +201,19 @@ func PickModels(openai_models, gemini_models, ollama_models []string) string {
 				picked_models_map[id] = true
 				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, openai_models[id-1])
 			}
-		} else if id > len(openai_models) && id <= len(ollama_models)+len(openai_models) {
+		} else if id > len(openai_models) && id <= len(gemini_models)+len(openai_models) {
+			// gemini model picked
+			if !picked_models_map[id] {
+				// if not already picked, add it to bin
+				picked_models_map[id] = true
+				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, gemini_models[id-len(openai_models)-1])
+			}
+		} else if id > len(openai_models)+len(gemini_models) && id <= len(ollama_models)+len(gemini_models)+len(openai_models) {
 			// ollama model picked
 			if !picked_models_map[id] {
 				// if not already picked, add it to bin
 				picked_models_map[id] = true
-				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(openai_models)-1])
+				picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(gemini_models)-len(openai_models)-1])
 			}
 		} else {
 			// out of index, invalid
@@ -209,9 +222,8 @@ func PickModels(openai_models, gemini_models, ollama_models []string) string {
 		}
 	}
 	if len(invalid_selections) != 0 {
-		fmt.Printf("Skipping the invalid selections: %s \n", FormatMapKeys(invalid_selections))
+		fmt.Printf("Skipping the invalid selections: %s \n\n", FormatMapKeys(invalid_selections))
 	}
-	fmt.Printf("\n")
 	return picked_models_str
 }
 
diff --git a/utils/gemini.go b/utils/gemini.go
new file mode 100644
index 0000000..29a4534
--- /dev/null
+++ b/utils/gemini.go
@@ -0,0 +1,23 @@
+package utils
+
+import "strings"
+
+// IsGeminiRequired checks if any of the picked models require Google Gemini by comparing them against a list of available Gemini models.
+//
+// Parameters:
+//   - picked_models: A comma-separated string of model names selected by the user.
+//   - gemini_models: A pointer to a slice of strings containing available Gemini model names.
+//
+// Returns:
+//   - bool: Returns true if any of the picked models match Gemini models, indicating that Gemini is required, otherwise false.
+func IsGeminiRequired(picked_models string, gemini_models *[]string) bool {
+	for _, model := range strings.Split(picked_models, ",") {
+		for _, gemini_model := range *gemini_models {
+			if model == gemini_model {
+				// a Gemini model was picked; no need to scan the rest
+				return true
+			}
+		}
+	}
+	return false
+}