Merge pull request #13 from firstbatchxyz/gemini-models
gemini models
selimseker authored Oct 18, 2024
2 parents 02a1ab6 + 20fed22 commit e91e364
Showing 3 changed files with 66 additions and 8 deletions.
main.go: 24 changes (23 additions, 1 deletion)
@@ -54,6 +54,18 @@ var (
"o1-preview",
}

GEMINI_MODELS = []string{
"gemini-1.0-pro",

"gemini-1.5-pro",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash",

"gemma-2-2b-it",
"gemma-2-9b-it",
"gemma-2-27b-it",
}

// Default admin public key, it will be used unless --dkn-admin-public-key is given
DKN_ADMIN_PUBLIC_KEY = "0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110be6ae658"
)
@@ -123,7 +135,7 @@ func main() {

// if DKN_MODELS are still empty, pick model interactively
if envvars["DKN_MODELS"] == "" || *pick_model {
pickedModels := utils.PickModels(OPENAI_MODELS, OLLAMA_MODELS)
pickedModels := utils.PickModels(OPENAI_MODELS, GEMINI_MODELS, OLLAMA_MODELS)
if pickedModels == "" {
fmt.Println("No valid model picked")
utils.ExitWithDelay(1)
@@ -141,6 +153,16 @@ func main() {
envvars["OPENAI_API_KEY"] = apikey
}

// check gemini api key
if utils.IsGeminiRequired(envvars["DKN_MODELS"], &GEMINI_MODELS) && envvars["GEMINI_API_KEY"] == "" {
apikey := utils.GetUserInput("Enter your Gemini API Key", true)
if apikey == "" {
fmt.Println("Invalid input, please place your GEMINI_API_KEY to .env file")
utils.ExitWithDelay(1)
}
envvars["GEMINI_API_KEY"] = apikey
}

// check ollama environment
if utils.IsOllamaRequired(envvars["DKN_MODELS"], &OLLAMA_MODELS) {
ollamaHost, ollamaPort := utils.HandleOllamaEnv(envvars["OLLAMA_HOST"], envvars["OLLAMA_PORT"])
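For illustration, the key-gating pattern added to main.go above (prompt for a provider's API key only when one of the picked models belongs to that provider) boils down to the following standalone sketch. The needsProviderKey helper is a simplified stand-in for utils.IsGeminiRequired and is not part of this commit; the Gemini names come from the GEMINI_MODELS list above, while the picked string is hypothetical.

package main

import (
	"fmt"
	"os"
	"strings"
)

// needsProviderKey is a simplified stand-in for utils.IsGeminiRequired:
// it reports whether any of the comma-separated picked models appears in
// the given provider's model list.
func needsProviderKey(pickedModels string, providerModels []string) bool {
	for _, picked := range strings.Split(pickedModels, ",") {
		for _, model := range providerModels {
			if picked == model {
				return true
			}
		}
	}
	return false
}

func main() {
	// Gemini names taken from the GEMINI_MODELS list added in main.go.
	geminiModels := []string{"gemini-1.5-pro", "gemini-1.5-flash", "gemma-2-9b-it"}
	picked := "o1-preview,gemini-1.5-flash" // a hypothetical DKN_MODELS value

	if needsProviderKey(picked, geminiModels) && os.Getenv("GEMINI_API_KEY") == "" {
		fmt.Println("a Gemini model is selected but GEMINI_API_KEY is not set")
	}
}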
utils/cmd.go: 26 changes (19 additions, 7 deletions)
@@ -130,15 +130,16 @@ func RunCommand(working_dir string, outputDest string, wait bool, timeout time.D
return pid, nil
}

// PickModels prompts the user to pick models from the available OpenAI and Ollama models.
// PickModels prompts the user to pick models from the available OpenAI, Google and Ollama models.
//
// Parameters:
// - openai_models: A slice of available OpenAI model names.
// - gemini_models: A slice of available Gemini model names.
// - ollama_models: A slice of available Ollama model names.
//
// Returns:
// - string: A comma-separated string of selected model names.
func PickModels(openai_models, ollama_models []string) string {
func PickModels(openai_models, gemini_models, ollama_models []string) string {

// column widths
idWidth := 4
@@ -160,8 +161,13 @@ func PickModels(openai_models, ollama_models []string) string {
provider := "OpenAI"
fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
}
for id, model := range ollama_models {
for id, model := range gemini_models {
modelId := len(openai_models) + id + 1
provider := "Google"
fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
}
for id, model := range ollama_models {
modelId := len(openai_models) + len(gemini_models) + id + 1
provider := "Ollama"
fmt.Printf("| %-*d | %-*s | %-*s |\n", idWidth, modelId, providerWidth, provider, nameWidth, model)
}
@@ -195,12 +201,19 @@ func PickModels(openai_models, ollama_models []string) string {
picked_models_map[id] = true
picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, openai_models[id-1])
}
} else if id > len(openai_models) && id <= len(ollama_models)+len(openai_models) {
} else if id > len(openai_models) && id <= len(gemini_models)+len(openai_models) {
// gemini model picked
if !picked_models_map[id] {
// if not already picked, add it to bin
picked_models_map[id] = true
picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, gemini_models[id-len(openai_models)-1])
}
} else if id > len(openai_models)+len(gemini_models) && id <= len(ollama_models)+len(gemini_models)+len(openai_models) {
// ollama model picked
if !picked_models_map[id] {
// if not already picked, add it to bin
picked_models_map[id] = true
picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(openai_models)-1])
picked_models_str = fmt.Sprintf("%s,%s", picked_models_str, ollama_models[id-len(gemini_models)-len(openai_models)-1])
}
} else {
// out of index, invalid
}
}
if len(invalid_selections) != 0 {
fmt.Printf("Skipping the invalid selections: %s \n", FormatMapKeys(invalid_selections))
fmt.Printf("Skipping the invalid selections: %s \n\n", FormatMapKeys(invalid_selections))
}
fmt.Printf("\n")
return picked_models_str
}

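The subtlest part of the PickModels change above is the 1-based ID arithmetic across the three provider lists (OpenAI first, then Google, then Ollama). The standalone sketch below reproduces that mapping; the OpenAI and Gemini names are taken from main.go, while the Ollama names are placeholders rather than values from this repository.

package main

import "fmt"

func main() {
	// OpenAI and Gemini names come from main.go; the Ollama names below
	// are placeholders for whatever OLLAMA_MODELS contains.
	openai := []string{"o1-preview"}
	gemini := []string{"gemini-1.5-pro", "gemini-1.5-flash"}
	ollama := []string{"ollama-model-a", "ollama-model-b"}

	// Same boundary checks and offsets as the updated PickModels.
	lookup := func(id int) (string, string) {
		switch {
		case id >= 1 && id <= len(openai):
			return "OpenAI", openai[id-1]
		case id > len(openai) && id <= len(openai)+len(gemini):
			return "Google", gemini[id-len(openai)-1]
		case id > len(openai)+len(gemini) && id <= len(openai)+len(gemini)+len(ollama):
			return "Ollama", ollama[id-len(openai)-len(gemini)-1]
		default:
			return "", "invalid selection"
		}
	}

	for id := 1; id <= len(openai)+len(gemini)+len(ollama); id++ {
		provider, model := lookup(id)
		fmt.Printf("%d -> %-6s %s\n", id, provider, model)
	}
}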
utils/gemini.go: 24 changes (24 additions, 0 deletions)
@@ -0,0 +1,24 @@
package utils

import "strings"

// IsGeminiRequired checks if any of the picked models require Google Gemini by comparing them against a list of available Gemini models.
//
// Parameters:
// - picked_models: A comma-separated string of model names selected by the user.
// - gemini_models: A pointer to a slice of strings containing available Gemini model names.
//
// Returns:
// - bool: Returns true if any of the picked models match Gemini models, indicating that Gemini is required, otherwise false.
func IsGeminiRequired(picked_models string, gemini_models *[]string) bool {
required := false
for _, model := range strings.Split(picked_models, ",") {
for _, gemini_model := range *gemini_models {
if model == gemini_model {
required = true
break
}
}
}
return required
}
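Since IsGeminiRequired is shown in full above, a table-style check is straightforward; the test below is a hypothetical sketch, not part of this commit, that would sit next to gemini.go in package utils, with model names taken from the GEMINI_MODELS list in main.go.

package utils

import "testing"

// Hypothetical test sketch for IsGeminiRequired; not part of this commit.
func TestIsGeminiRequired(t *testing.T) {
	geminiModels := []string{"gemini-1.5-pro", "gemini-1.5-flash", "gemma-2-9b-it"}

	if !IsGeminiRequired("o1-preview,gemini-1.5-flash", &geminiModels) {
		t.Error("expected true when a Gemini model is among the picked models")
	}
	if IsGeminiRequired("o1-preview", &geminiModels) {
		t.Error("expected false when no Gemini model is picked")
	}
	if IsGeminiRequired("", &geminiModels) {
		t.Error("expected false for an empty selection")
	}
}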
