Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved configuration validation #278

Merged
merged 11 commits into from
Jan 20, 2025
40 changes: 34 additions & 6 deletions server/ai/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ const (
UserAccessLevelNone
)

const (
ServiceTypeOpenAI = "openai"
ServiceTypeOpenAICompatible = "openaicompatible"
ServiceTypeAzure = "azure"
ServiceTypeAskSage = "asksage"
ServiceTypeAnthropic = "anthropic"
)

type BotConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Expand All @@ -49,10 +57,30 @@ type BotConfig struct {
}

func (c *BotConfig) IsValid() bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nice to have this return an error with a specific failure. (I bet Aider could make this change pretty quick)

isInvalid := c.Name == "" ||
c.DisplayName == "" ||
c.Service.Type == "" ||
((c.Service.Type == "openaicompatible" || c.Service.Type == "azure") && c.Service.APIURL == "") ||
(c.Service.Type != "asksage" && c.Service.Type != "openaicompatible" && c.Service.Type != "azure" && c.Service.APIKey == "")
return !isInvalid
// Basic validation
if c.Name == "" || c.DisplayName == "" || c.Service.Type == "" {
return false
}

// Validate access levels are within bounds
if c.ChannelAccessLevel < ChannelAccessLevelAll || c.ChannelAccessLevel > ChannelAccessLevelNone {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is backwards. UserAccessLevelAll = 0 and ChannelAccessLevelNone = 3

Copy link
Contributor Author

@enzowritescode enzowritescode Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand your comment.

ChannelAccessLevelAll = 0 and ChannelAccessLevelNone = 3. The code checks that the value passed to it is not less than 0 and not greater than 3. Is that not correct?

return false
}
if c.UserAccessLevel < UserAccessLevelAll || c.UserAccessLevel > UserAccessLevelNone {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

@enzowritescode enzowritescode Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UserAccessLevelAll = 0 and UserAccessLevelNone = 3. The code checks that the value passed to it is not less than 0 and not greater than 3. Is that not correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I read those wrong. Don't do code reviews while jet-legged kids 🛫

return false
}

// Service-specific validation
switch c.Service.Type {
case ServiceTypeOpenAI:
return c.Service.APIKey != "" && c.Service.OrgID != ""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OrgID is optional.

case ServiceTypeOpenAICompatible, ServiceTypeAzure:
return c.Service.APIKey != "" && c.Service.APIURL != ""
case ServiceTypeAnthropic:
return c.Service.APIKey != ""
case ServiceTypeAskSage:
return c.Service.Username != "" && c.Service.Password != ""
default:
return false
}
}
262 changes: 262 additions & 0 deletions server/ai/configuration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
package ai

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBotConfig_IsValid(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't a success case. So a function that always fails will pass. (Which is the current case)

type fields struct {
ID string
Name string
DisplayName string
CustomInstructions string
Service ServiceConfig
EnableVision bool
DisableTools bool
ChannelAccessLevel ChannelAccessLevel
ChannelIDs []string
UserAccessLevel UserAccessLevel
UserIDs []string
TeamIDs []string
MaxFileSize int64
}
tests := []struct {
name string
fields fields
want bool
}{
{
name: "Invalid name",
fields: fields{
ID: "xxx",
Name: "", // bad
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid display name",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "", // bad
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid service type",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "mattermostllm", // bad
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid channel access level < ChannelAccessLevelAll",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll - 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Invalid channel access level > ChannelAccessLevelNone",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Invalid user access level < UserAccessLevelAll",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll - 1, // bad
},
want: false,
},
{
name: "Invalid user access level > UserAccessLevelNone",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelNone + 1, // bad
},
want: false,
},
{
name: "OpenAI compatible required API URL",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openaicompatible",
APIKey: "sk-xyz",
APIURL: "", // bad
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage requires username",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "", // bad
Password: "topsecret",
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage requires password",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "myuser",
Password: "", // bad
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &BotConfig{
ID: tt.fields.ID,
Name: tt.fields.Name,
DisplayName: tt.fields.DisplayName,
CustomInstructions: tt.fields.CustomInstructions,
Service: tt.fields.Service,
EnableVision: tt.fields.EnableVision,
DisableTools: tt.fields.DisableTools,
ChannelAccessLevel: tt.fields.ChannelAccessLevel,
ChannelIDs: tt.fields.ChannelIDs,
UserAccessLevel: tt.fields.UserAccessLevel,
UserIDs: tt.fields.UserIDs,
TeamIDs: tt.fields.TeamIDs,
MaxFileSize: tt.fields.MaxFileSize,
}
assert.Equalf(t, tt.want, c.IsValid(), "IsValid()")
})
}
}
6 changes: 3 additions & 3 deletions server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ func postsToChatCompletionMessages(posts []ai.Post) []openaiClient.ChatCompletio
return result
}

// createFunctionArrgmentResolver Creates a resolver for the json arguments of an openai function call. Unmarshaling the json into the supplied struct.
func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter {
// createFunctionArgumentResolver Creates a resolver for the json arguments of an openai function call. Unmarshalling the json into the supplied struct.
func createFunctionArgumentResolver(jsonArgs string) ai.ToolArgumentGetter {
return func(args any) error {
return json.Unmarshal([]byte(jsonArgs), args)
}
Expand Down Expand Up @@ -317,7 +317,7 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque
name := tool.Function.Name
arguments := tool.Function.Arguments
toolID := tool.ID
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArgumentResolver(arguments), conversation.Context)
if err != nil {
fmt.Printf("Error resolving function %s: %s", name, err)
}
Expand Down
10 changes: 5 additions & 5 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel {

var llm ai.LanguageModel
switch llmBotConfig.Service.Type {
case "openai":
case ai.ServiceTypeOpenAI:
llm = openai.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "openaicompatible":
case ai.ServiceTypeOpenAICompatible:
llm = openai.NewCompatible(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "azure":
case ai.ServiceTypeAzure:
llm = openai.NewAzure(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "anthropic":
case ai.ServiceTypeAnthropic:
llm = anthropic.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "asksage":
case ai.ServiceTypeAskSage:
llm = asksage.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
}

Expand Down
Loading