-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improved configuration validation #278
Changes from 7 commits
0035043
836386d
c8ad674
1634948
31aa957
0827270
f15302e
4a0eef0
c08a40a
4f0ce22
a00251e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,14 @@ const ( | |
UserAccessLevelNone | ||
) | ||
|
||
const ( | ||
ServiceTypeOpenAI = "openai" | ||
ServiceTypeOpenAICompatible = "openaicompatible" | ||
ServiceTypeAzure = "azure" | ||
ServiceTypeAskSage = "asksage" | ||
ServiceTypeAnthropic = "anthropic" | ||
) | ||
|
||
type BotConfig struct { | ||
ID string `json:"id"` | ||
Name string `json:"name"` | ||
|
@@ -49,10 +57,30 @@ type BotConfig struct { | |
} | ||
|
||
func (c *BotConfig) IsValid() bool { | ||
isInvalid := c.Name == "" || | ||
c.DisplayName == "" || | ||
c.Service.Type == "" || | ||
((c.Service.Type == "openaicompatible" || c.Service.Type == "azure") && c.Service.APIURL == "") || | ||
(c.Service.Type != "asksage" && c.Service.Type != "openaicompatible" && c.Service.Type != "azure" && c.Service.APIKey == "") | ||
return !isInvalid | ||
// Basic validation | ||
if c.Name == "" || c.DisplayName == "" || c.Service.Type == "" { | ||
return false | ||
} | ||
|
||
// Validate access levels are within bounds | ||
if c.ChannelAccessLevel < ChannelAccessLevelAll || c.ChannelAccessLevel > ChannelAccessLevelNone { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is backwards. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand your comment.
|
||
return false | ||
} | ||
if c.UserAccessLevel < UserAccessLevelAll || c.UserAccessLevel > UserAccessLevelNone { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I read those wrong. Don't do code reviews while jet-legged kids 🛫 |
||
return false | ||
} | ||
|
||
// Service-specific validation | ||
switch c.Service.Type { | ||
case ServiceTypeOpenAI: | ||
return c.Service.APIKey != "" && c.Service.OrgID != "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The OrgID is optional. |
||
case ServiceTypeOpenAICompatible, ServiceTypeAzure: | ||
return c.Service.APIKey != "" && c.Service.APIURL != "" | ||
case ServiceTypeAnthropic: | ||
return c.Service.APIKey != "" | ||
case ServiceTypeAskSage: | ||
return c.Service.Username != "" && c.Service.Password != "" | ||
default: | ||
return false | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
package ai | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestBotConfig_IsValid(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There isn't a success case. So a function that always fails will pass. (Which is the current case) |
||
type fields struct { | ||
ID string | ||
Name string | ||
DisplayName string | ||
CustomInstructions string | ||
Service ServiceConfig | ||
EnableVision bool | ||
DisableTools bool | ||
ChannelAccessLevel ChannelAccessLevel | ||
ChannelIDs []string | ||
UserAccessLevel UserAccessLevel | ||
UserIDs []string | ||
TeamIDs []string | ||
MaxFileSize int64 | ||
} | ||
tests := []struct { | ||
name string | ||
fields fields | ||
want bool | ||
}{ | ||
{ | ||
name: "Invalid name", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "", // bad | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll, | ||
UserAccessLevel: UserAccessLevelAll, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid display name", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "", // bad | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll, | ||
UserAccessLevel: UserAccessLevelAll, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid service type", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "mattermostllm", // bad | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll, | ||
UserAccessLevel: UserAccessLevelAll, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid channel access level < ChannelAccessLevelAll", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll - 1, // bad | ||
UserAccessLevel: UserAccessLevelNone, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid channel access level > ChannelAccessLevelNone", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelNone + 1, // bad | ||
UserAccessLevel: UserAccessLevelNone, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid user access level < UserAccessLevelAll", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll, | ||
UserAccessLevel: UserAccessLevelAll - 1, // bad | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Invalid user access level > UserAccessLevelNone", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openai", | ||
APIKey: "sk-xyz", | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelAll, | ||
UserAccessLevel: UserAccessLevelNone + 1, // bad | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "OpenAI compatible required API URL", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "openaicompatible", | ||
APIKey: "sk-xyz", | ||
APIURL: "", // bad | ||
OrgID: "org-xyz", | ||
DefaultModel: "gpt-40", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelNone + 1, | ||
UserAccessLevel: UserAccessLevelNone, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Ask Sage requires username", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "asksage", | ||
Username: "", // bad | ||
Password: "topsecret", | ||
DefaultModel: "xxx", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelNone + 1, | ||
UserAccessLevel: UserAccessLevelNone, | ||
}, | ||
want: false, | ||
}, | ||
{ | ||
name: "Ask Sage requires password", | ||
fields: fields{ | ||
ID: "xxx", | ||
Name: "xxx", | ||
DisplayName: "xxx", | ||
CustomInstructions: "", | ||
Service: ServiceConfig{ | ||
Name: "Copilot", | ||
Type: "asksage", | ||
Username: "myuser", | ||
Password: "", // bad | ||
DefaultModel: "xxx", | ||
TokenLimit: 100, | ||
StreamingTimeoutSeconds: 60, | ||
}, | ||
ChannelAccessLevel: ChannelAccessLevelNone + 1, | ||
UserAccessLevel: UserAccessLevelNone, | ||
}, | ||
want: false, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
c := &BotConfig{ | ||
ID: tt.fields.ID, | ||
Name: tt.fields.Name, | ||
DisplayName: tt.fields.DisplayName, | ||
CustomInstructions: tt.fields.CustomInstructions, | ||
Service: tt.fields.Service, | ||
EnableVision: tt.fields.EnableVision, | ||
DisableTools: tt.fields.DisableTools, | ||
ChannelAccessLevel: tt.fields.ChannelAccessLevel, | ||
ChannelIDs: tt.fields.ChannelIDs, | ||
UserAccessLevel: tt.fields.UserAccessLevel, | ||
UserIDs: tt.fields.UserIDs, | ||
TeamIDs: tt.fields.TeamIDs, | ||
MaxFileSize: tt.fields.MaxFileSize, | ||
} | ||
assert.Equalf(t, tt.want, c.IsValid(), "IsValid()") | ||
}) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be nice to have this return an error with a specific failure. (I bet Aider could make this change pretty quick)