Skip to content

Commit

Permalink
Improved configuration validation (#278)
Browse files Browse the repository at this point in the history
* Improved configuration validation from aider feedback plus added additional validation

* Add tests, remove hard coded strings in favor of constants, typo fixes

* Remove accidental import

* Undo change

* Fix linting issue

* add comment to test

* Remove optional value from validation

* Add valid test cases

* Improve tests

---------

Co-authored-by: Christopher Speller <[email protected]>
  • Loading branch information
enzowritescode and crspeller authored Jan 20, 2025
1 parent 0e1814d commit 06c7e57
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 14 deletions.
40 changes: 34 additions & 6 deletions server/ai/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ const (
UserAccessLevelNone
)

const (
ServiceTypeOpenAI = "openai"
ServiceTypeOpenAICompatible = "openaicompatible"
ServiceTypeAzure = "azure"
ServiceTypeAskSage = "asksage"
ServiceTypeAnthropic = "anthropic"
)

type BotConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Expand All @@ -49,10 +57,30 @@ type BotConfig struct {
}

func (c *BotConfig) IsValid() bool {
isInvalid := c.Name == "" ||
c.DisplayName == "" ||
c.Service.Type == "" ||
((c.Service.Type == "openaicompatible" || c.Service.Type == "azure") && c.Service.APIURL == "") ||
(c.Service.Type != "asksage" && c.Service.Type != "openaicompatible" && c.Service.Type != "azure" && c.Service.APIKey == "")
return !isInvalid
// Basic validation
if c.Name == "" || c.DisplayName == "" || c.Service.Type == "" {
return false
}

// Validate access levels are within bounds
if c.ChannelAccessLevel < ChannelAccessLevelAll || c.ChannelAccessLevel > ChannelAccessLevelNone {
return false
}
if c.UserAccessLevel < UserAccessLevelAll || c.UserAccessLevel > UserAccessLevelNone {
return false
}

// Service-specific validation
switch c.Service.Type {
case ServiceTypeOpenAI:
return c.Service.APIKey != ""
case ServiceTypeOpenAICompatible, ServiceTypeAzure:
return c.Service.APIKey != "" && c.Service.APIURL != ""
case ServiceTypeAnthropic:
return c.Service.APIKey != ""
case ServiceTypeAskSage:
return c.Service.Username != "" && c.Service.Password != ""
default:
return false
}
}
304 changes: 304 additions & 0 deletions server/ai/configuration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
package ai

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBotConfig_IsValid(t *testing.T) {
type fields struct {
ID string
Name string
DisplayName string
CustomInstructions string
Service ServiceConfig
EnableVision bool
DisableTools bool
ChannelAccessLevel ChannelAccessLevel
ChannelIDs []string
UserAccessLevel UserAccessLevel
UserIDs []string
TeamIDs []string
MaxFileSize int64
}
tests := []struct {
name string
fields fields
want bool
}{
{
name: "Valid OpenAI configuration with minimal required fields",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: true,
},
{
name: "Valid OpenAI configuration with ChannelAccessLevelNone",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone,
UserAccessLevel: UserAccessLevelAll,
},
want: true,
},
{
name: "Bot name cannot be empty",
fields: fields{
ID: "xxx",
Name: "", // bad
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Bot display name cannot be empty",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "", // bad
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Service type must be one of the supported providers",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "mattermostllm", // bad
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Channel access level cannot be less than ChannelAccessLevelAll (0)",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll - 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Channel access level cannot be greater than ChannelAccessLevelNone (3)",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "User access level cannot be less than UserAccessLevelAll (0)",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll - 1, // bad
},
want: false,
},
{
name: "User access level cannot be greater than UserAccessLevelNone (3)",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelNone + 1, // bad
},
want: false,
},
{
name: "OpenAI compatible service requires API URL to be set",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openaicompatible",
APIKey: "sk-xyz",
APIURL: "", // bad
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage service requires username to be set",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "", // bad
Password: "topsecret",
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage service requires password to be set",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "myuser",
Password: "", // bad
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &BotConfig{
ID: tt.fields.ID,
Name: tt.fields.Name,
DisplayName: tt.fields.DisplayName,
CustomInstructions: tt.fields.CustomInstructions,
Service: tt.fields.Service,
EnableVision: tt.fields.EnableVision,
DisableTools: tt.fields.DisableTools,
ChannelAccessLevel: tt.fields.ChannelAccessLevel,
ChannelIDs: tt.fields.ChannelIDs,
UserAccessLevel: tt.fields.UserAccessLevel,
UserIDs: tt.fields.UserIDs,
TeamIDs: tt.fields.TeamIDs,
MaxFileSize: tt.fields.MaxFileSize,
}
assert.Equalf(t, tt.want, c.IsValid(), "IsValid() for test case %q", tt.name)
})
}
}
6 changes: 3 additions & 3 deletions server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ func postsToChatCompletionMessages(posts []ai.Post) []openaiClient.ChatCompletio
return result
}

// createFunctionArrgmentResolver Creates a resolver for the json arguments of an openai function call. Unmarshaling the json into the supplied struct.
func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter {
// createFunctionArgumentResolver Creates a resolver for the json arguments of an openai function call. Unmarshalling the json into the supplied struct.
func createFunctionArgumentResolver(jsonArgs string) ai.ToolArgumentGetter {
return func(args any) error {
return json.Unmarshal([]byte(jsonArgs), args)
}
Expand Down Expand Up @@ -317,7 +317,7 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque
name := tool.Function.Name
arguments := tool.Function.Arguments
toolID := tool.ID
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArgumentResolver(arguments), conversation.Context)
if err != nil {
fmt.Printf("Error resolving function %s: %s", name, err)
}
Expand Down
10 changes: 5 additions & 5 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel {

var llm ai.LanguageModel
switch llmBotConfig.Service.Type {
case "openai":
case ai.ServiceTypeOpenAI:
llm = openai.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "openaicompatible":
case ai.ServiceTypeOpenAICompatible:
llm = openai.NewCompatible(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "azure":
case ai.ServiceTypeAzure:
llm = openai.NewAzure(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "anthropic":
case ai.ServiceTypeAnthropic:
llm = anthropic.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "asksage":
case ai.ServiceTypeAskSage:
llm = asksage.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
}

Expand Down

0 comments on commit 06c7e57

Please sign in to comment.