Skip to content

Commit

Permalink
Add tests, remove hard coded strings in favor of constants, typo fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
enzowritescode committed Jan 16, 2025
1 parent 0035043 commit 836386d
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 17 deletions.
20 changes: 16 additions & 4 deletions server/ai/configuration.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ai

import "C"

type ServiceConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Expand Down Expand Up @@ -32,6 +34,14 @@ const (
UserAccessLevelNone
)

const (
ServiceTypeOpenAI = "openai"
ServiceTypeOpenAICompatible = "openaicompatible"
ServiceTypeAzure = "azure"
ServiceTypeAskSage = "asksage"
ServiceTypeAnthropic = "anthropic"
)

type BotConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Expand Down Expand Up @@ -64,11 +74,13 @@ func (c *BotConfig) IsValid() bool {

// Service-specific validation
switch c.Service.Type {
case "openai", "anthropic":
case ServiceTypeOpenAI:
return c.Service.APIKey != "" && c.Service.OrgID != ""
case ServiceTypeOpenAICompatible, ServiceTypeAzure:
return c.Service.APIKey != "" && c.Service.APIURL != ""
case ServiceTypeAnthropic:
return c.Service.APIKey != ""
case "openaicompatible", "azure":
return c.Service.APIURL != ""
case "asksage":
case ServiceTypeAskSage:
return c.Service.Username != "" && c.Service.Password != ""
default:
return false
Expand Down
261 changes: 261 additions & 0 deletions server/ai/configuration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package ai

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestBotConfig_IsValid(t *testing.T) {
type fields struct {
ID string
Name string
DisplayName string
CustomInstructions string
Service ServiceConfig
EnableVision bool
DisableTools bool
ChannelAccessLevel ChannelAccessLevel
ChannelIDs []string
UserAccessLevel UserAccessLevel
UserIDs []string
TeamIDs []string
MaxFileSize int64
}
tests := []struct {
name string
fields fields
want bool
}{
{
name: "Invalid name",
fields: fields{
ID: "xxx",
Name: "", // bad
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid display name",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "", // bad
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid service type",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "mattermostllm",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll,
},
want: false,
},
{
name: "Invalid channel access level < ChannelAccessLevelAll",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll - 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Invalid channel access level > ChannelAccessLevelNone",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1, // bad
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Invalid user access level < UserAccessLevelAll",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelAll - 1, // bad
},
want: false,
},
{
name: "Invalid user access level > UserAccessLevelNone",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openai",
APIKey: "sk-xyz",
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelAll,
UserAccessLevel: UserAccessLevelNone + 1, // bad
},
want: false,
},
{
name: "OpenAI compatible required API URL",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "openaicompatible",
APIKey: "sk-xyz",
APIURL: "", // bad
OrgID: "org-xyz",
DefaultModel: "gpt-40",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage requires username",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "", // bad
Password: "topsecret",
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
{
name: "Ask Sage requires password",
fields: fields{
ID: "xxx",
Name: "xxx",
DisplayName: "xxx",
CustomInstructions: "",
Service: ServiceConfig{
Name: "Copilot",
Type: "asksage",
Username: "myuser",
Password: "", // bad
DefaultModel: "xxx",
TokenLimit: 100,
StreamingTimeoutSeconds: 60,
},
ChannelAccessLevel: ChannelAccessLevelNone + 1,
UserAccessLevel: UserAccessLevelNone,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &BotConfig{
ID: tt.fields.ID,
Name: tt.fields.Name,
DisplayName: tt.fields.DisplayName,
CustomInstructions: tt.fields.CustomInstructions,
Service: tt.fields.Service,
EnableVision: tt.fields.EnableVision,
DisableTools: tt.fields.DisableTools,
ChannelAccessLevel: tt.fields.ChannelAccessLevel,
ChannelIDs: tt.fields.ChannelIDs,
UserAccessLevel: tt.fields.UserAccessLevel,
UserIDs: tt.fields.UserIDs,
TeamIDs: tt.fields.TeamIDs,
MaxFileSize: tt.fields.MaxFileSize,
}
assert.Equalf(t, tt.want, c.IsValid(), "IsValid()")
})
}
}
16 changes: 8 additions & 8 deletions server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ type OpenAI struct {
sendUserID bool
}

const StreamingTimeoutDefault = 10 * time.Second

const MaxFunctionCalls = 10

const OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB
const (
StreamingTimeoutDefault = 10 * time.Second
MaxFunctionCalls = 10
OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB
)

var ErrStreamingTimeout = errors.New("timeout streaming")

Expand Down Expand Up @@ -192,8 +192,8 @@ func postsToChatCompletionMessages(posts []ai.Post) []openaiClient.ChatCompletio
return result
}

// createFunctionArrgmentResolver Creates a resolver for the json arguments of an openai function call. Unmarshaling the json into the supplied struct.
func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter {
// createFunctionArgumentResolver Creates a resolver for the json arguments of an openai function call. Unmarshalling the json into the supplied struct.
func createFunctionArgumentResolver(jsonArgs string) ai.ToolArgumentGetter {
return func(args any) error {
return json.Unmarshal([]byte(jsonArgs), args)
}
Expand Down Expand Up @@ -317,7 +317,7 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque
name := tool.Function.Name
arguments := tool.Function.Arguments
toolID := tool.ID
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArgumentResolver(arguments), conversation.Context)
if err != nil {
fmt.Printf("Error resolving function %s: %s", name, err)
}
Expand Down
10 changes: 5 additions & 5 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel {

var llm ai.LanguageModel
switch llmBotConfig.Service.Type {
case "openai":
case ai.ServiceTypeOpenAI:
llm = openai.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "openaicompatible":
case ai.ServiceTypeOpenAICompatible:
llm = openai.NewCompatible(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "azure":
case ai.ServiceTypeAzure:
llm = openai.NewAzure(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "anthropic":
case ai.ServiceTypeAnthropic:
llm = anthropic.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
case "asksage":
case ai.ServiceTypeAskSage:
llm = asksage.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics)
}

Expand Down

0 comments on commit 836386d

Please sign in to comment.