Skip to content

Commit

Permalink
fix: improve performance for checking for configured model providers
Browse files Browse the repository at this point in the history
This change simplifies the process for checking for configured model
providers. When listing model providers, instead of making a call to
reveal each credential, this change will make one call to list
credentials.

A side effect is that empty environment variables are not treated as
"configured" while there were not previously. Care is taken to remove
empty environment variables from the credential before saving it.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Dec 10, 2024
1 parent 11d536a commit 9ca1b5c
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 96 deletions.
18 changes: 17 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dev-open: ARGS=--open-uis
dev-open: dev

# Lint the project
lint: lint-admin
lint: lint-admin lint-go

lint-admin:
cd ui/admin && \
Expand All @@ -49,6 +49,22 @@ lint-admin:
package-tools:
./tools/package-tools.sh

tidy:
go mod tidy

GOLANGCI_LINT_VERSION ?= v1.62.2
setup-env:
if ! command -v golangci-lint &> /dev/null; then \
echo "Could not find golangci-lint, installing version $(GOLANGCI_LINT_VERSION)."; \
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin $(GOLANGCI_LINT_VERSION); \
fi

lint-go: setup-env
golangci-lint run

# Runs linters and validates that all generated code is committed.
validate-code: tidy lint no-changes

no-changes:
@if [ -n "$$(git status --porcelain)" ]; then \
git status --porcelain; \
Expand Down
13 changes: 6 additions & 7 deletions apiclient/types/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ type ToolReferenceManifest struct {
type ToolReference struct {
Metadata
ToolReferenceManifest
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credential string `json:"credential,omitempty"`
Params map[string]string `json:"params,omitempty"`
ModelProviderStatus *ModelProviderStatus `json:"modelProviderStatus,omitempty"`
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credential string `json:"credential,omitempty"`
Params map[string]string `json:"params,omitempty"`
}

type ToolReferenceList List[ToolReference]
5 changes: 0 additions & 5 deletions apiclient/types/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 35 additions & 9 deletions pkg/api/handlers/availablemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func NewAvailableModelsHandler(gClient *gptscript.GPTScript, dispatcher *dispatc
}

func (a *AvailableModelsHandler) List(req api.Context) error {
var modelProviderReference v1.ToolReferenceList
if err := req.List(&modelProviderReference, &kclient.ListOptions{
var modelProviderReferences v1.ToolReferenceList
if err := req.List(&modelProviderReferences, &kclient.ListOptions{
Namespace: req.Namespace(),
FieldSelector: fields.SelectorFromSet(map[string]string{
"spec.type": string(types.ToolReferenceTypeModelProvider),
Expand All @@ -39,11 +39,27 @@ func (a *AvailableModelsHandler) List(req api.Context) error {
return err
}

credCtxs := make([]string, 0, len(modelProviderReferences.Items))
for _, ref := range modelProviderReferences.Items {
credCtxs = append(credCtxs, string(ref.UID))
}

creds, err := a.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{
CredentialContexts: credCtxs,
})
if err != nil {
return fmt.Errorf("failed to list model provider credentials: %w", err)
}

credMap := make(map[string]map[string]string, len(creds))
for _, cred := range creds {
credMap[cred.Context+cred.ToolName] = cred.Env
}

var oModels openai.ModelsList
for _, modelProvider := range modelProviderReference.Items {
if convertedModelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProvider); err != nil {
return fmt.Errorf("failed to determine if model provider %q is configured: %w", modelProvider.Name, err)
} else if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool {
for _, modelProvider := range modelProviderReferences.Items {
convertedModelProvider := convertModelProviderToolRef(modelProvider, credMap[string(modelProvider.UID)+modelProvider.Name])
if !convertedModelProvider.Configured || modelProvider.Name == system.ModelProviderTool {
continue
}

Expand Down Expand Up @@ -74,9 +90,19 @@ func (a *AvailableModelsHandler) ListForModelProvider(req api.Context) error {
return types.NewErrBadRequest("%s is not a model provider", modelProviderReference.Name)
}

if modelProvider, err := convertModelProviderToolRef(req.Context(), a.gptscript, modelProviderReference); err != nil {
return fmt.Errorf("failed to determine if model provider is configured: %w", err)
} else if !modelProvider.Configured {
var credEnvVars map[string]string
if modelProviderReference.Status.Tool != nil {
if envVars := modelProviderReference.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := a.gptscript.RevealCredential(req.Context(), []string{string(modelProviderReference.UID)}, modelProviderReference.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", modelProviderReference.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

if modelProvider := convertModelProviderToolRef(modelProviderReference, credEnvVars); !modelProvider.Configured {
return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProviderReference.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", "))
}

Expand Down
102 changes: 64 additions & 38 deletions pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -39,12 +38,19 @@ func (mp *ModelProviderHandler) ByID(req api.Context) error {
)
}

modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return err
var credEnvVars map[string]string
if ref.Status.Tool != nil {
if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

return req.Write(modelProvider)
return req.Write(convertToolReferenceToModelProvider(ref, credEnvVars))
}

func (mp *ModelProviderHandler) List(req api.Context) error {
Expand All @@ -58,14 +64,26 @@ func (mp *ModelProviderHandler) List(req api.Context) error {
return err
}

resp := make([]types.ModelProvider, 0, len(refList.Items))
credCtxs := make([]string, 0, len(refList.Items))
for _, ref := range refList.Items {
modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return fmt.Errorf("failed to determine model provider status: %w", err)
}
credCtxs = append(credCtxs, string(ref.UID))
}

resp = append(resp, modelProvider)
creds, err := mp.gptscript.ListCredentials(req.Context(), gptscript.ListCredentialsOptions{
CredentialContexts: credCtxs,
})
if err != nil {
return fmt.Errorf("failed to list model provider credentials: %w", err)
}

credMap := make(map[string]map[string]string, len(creds))
for _, cred := range creds {
credMap[cred.Context+cred.ToolName] = cred.Env
}

resp := make([]types.ModelProvider, 0, len(refList.Items))
for _, ref := range refList.Items {
resp = append(resp, convertToolReferenceToModelProvider(ref, credMap[string(ref.UID)+ref.Name]))
}

return req.Write(types.ModelProviderList{Items: resp})
Expand All @@ -91,6 +109,12 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error {
return fmt.Errorf("failed to update credential: %w", err)
}

for key, val := range envVars {
if val == "" {
delete(envVars, key)
}
}

if err := mp.gptscript.CreateCredential(req.Context(), gptscript.Credential{
Context: string(ref.UID),
ToolName: ref.Name,
Expand Down Expand Up @@ -120,12 +144,18 @@ func (mp *ModelProviderHandler) Reveal(req api.Context) error {
return err
}

if ref.Spec.Type != types.ToolReferenceTypeModelProvider {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential: %w", err)
} else if err == nil {
return req.Write(cred.Env)
}

return req.Write(cred.Env)
return types.NewErrNotFound("no credential found for %q", ref.Name)
}

func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
Expand All @@ -138,11 +168,19 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

modelProvider, err := convertToolReferenceToModelProvider(req.Context(), mp.gptscript, ref)
if err != nil {
return err
var credEnvVars map[string]string
if ref.Status.Tool != nil {
if envVars := ref.Status.Tool.Metadata["envVars"]; envVars != "" {
cred, err := mp.gptscript.RevealCredential(req.Context(), []string{string(ref.UID)}, ref.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return fmt.Errorf("failed to reveal credential for model provider %q: %w", ref.Name, err)
} else if err == nil {
credEnvVars = cred.Env
}
}
}

modelProvider := convertToolReferenceToModelProvider(ref, credEnvVars)
if !modelProvider.Configured {
return types.NewErrBadRequest("model provider %s is not configured, missing configuration parameters: %s", modelProvider.ModelProviderManifest.Name, strings.Join(modelProvider.MissingConfigurationParameters, ", "))
}
Expand All @@ -156,19 +194,14 @@ func (mp *ModelProviderHandler) RefreshModels(req api.Context) error {
delete(ref.Annotations, v1.ModelProviderSyncAnnotation)
}

if err = req.Update(&ref); err != nil {
if err := req.Update(&ref); err != nil {
return fmt.Errorf("failed to sync models for model provider %q: %w", ref.Name, err)
}

return req.Write(modelProvider)
}

func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript.GPTScript, ref v1.ToolReference) (types.ModelProvider, error) {
status, err := convertModelProviderToolRef(ctx, gClient, ref)
if err != nil {
return types.ModelProvider{}, err
}

func convertToolReferenceToModelProvider(ref v1.ToolReference, credEnvVars map[string]string) types.ModelProvider {
name := ref.Name
if ref.Status.Tool != nil {
name = ref.Status.Tool.Name
Expand All @@ -180,34 +213,27 @@ func convertToolReferenceToModelProvider(ctx context.Context, gClient *gptscript
Name: name,
ToolReference: ref.Spec.Reference,
},
ModelProviderStatus: *status,
ModelProviderStatus: *convertModelProviderToolRef(ref, credEnvVars),
}

mp.Type = "modelprovider"

return mp, nil
return mp
}

func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTScript, toolRef v1.ToolReference) (*types.ModelProviderStatus, error) {
func convertModelProviderToolRef(toolRef v1.ToolReference, cred map[string]string) *types.ModelProviderStatus {
var (
requiredEnvVars, missingEnvVars []string
icon string
)
if toolRef.Status.Tool != nil {
if toolRef.Status.Tool.Metadata["envVars"] != "" {
cred, err := gptscript.RevealCredential(ctx, []string{string(toolRef.UID)}, toolRef.Name)
if err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
return nil, fmt.Errorf("failed to reveal credential for model provider %q: %w", toolRef.Name, err)
}

if toolRef.Status.Tool.Metadata["envVars"] != "" {
requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",")
}
requiredEnvVars = strings.Split(toolRef.Status.Tool.Metadata["envVars"], ",")
}

for _, envVar := range requiredEnvVars {
if cred.Env[envVar] == "" {
missingEnvVars = append(missingEnvVars, envVar)
}
for _, envVar := range requiredEnvVars {
if _, ok := cred[envVar]; !ok {
missingEnvVars = append(missingEnvVars, envVar)
}
}

Expand All @@ -227,5 +253,5 @@ func convertModelProviderToolRef(ctx context.Context, gptscript *gptscript.GPTSc
ModelsBackPopulated: modelsPopulated,
RequiredConfigurationParameters: requiredEnvVars,
MissingConfigurationParameters: missingEnvVars,
}, nil
}
}
Loading

0 comments on commit 9ca1b5c

Please sign in to comment.