Skip to content

Commit

Permalink
Reject invalid tool values produced by the LLM (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
crspeller authored Jan 3, 2024
1 parent e3223a5 commit b626aaf
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions server/built_in_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"

"github.com/google/go-github/v41/github"
"github.com/mattermost/mattermost-plugin-ai/server/ai"
Expand All @@ -23,6 +24,10 @@ func (p *Plugin) toolResolveLookupMattermostUser(context ai.ConversationContext,
return "", errors.Wrap(err, "failed to get arguments for tool LookupMattermostUser")
}

if !model.IsValidUsername(args.Username) {
return "invalid username", errors.New("invalid username")
}

// Fail for guests.
if !p.pluginAPI.User.HasPermissionTo(context.RequestingUser.Id, model.PermissionViewMembers) {
return "", errors.New("user doesn't have permission to lookup users")
Expand Down Expand Up @@ -75,6 +80,14 @@ func (p *Plugin) toolResolveGetChannelPosts(context ai.ConversationContext, args
return "invalid parameters to function", errors.Wrap(err, "failed to get arguments for tool GetChannelPosts")
}

if !model.IsValidChannelIdentifier(args.ChannelName) {
return "invalid channel name", errors.New("invalid channel name")
}

if args.NumberPosts < 1 || args.NumberPosts > 100 {
return "invalid number of posts. only 100 supported at a time", errors.New("invalid number of posts")
}

if context.Channel == nil || context.Channel.TeamId == "" {
//TODO: support DMs. This will require some way to disabiguate between channels with the same name on different teams.
return "Error: Ambiguous channel lookup. Unable to what channel the user is reffering to because DMs do not belong to specific teams. Tell the user to ask outside a DM channel.", errors.New("ambiguous channel lookup")
Expand Down Expand Up @@ -116,13 +129,30 @@ func formatIssue(issue *github.Issue) string {
return fmt.Sprintf("Title: %s\nNumber: %d\nState: %s\nSubmitter: %s\nIs Pull Request: %v\nBody: %s", issue.GetTitle(), issue.GetNumber(), issue.GetState(), issue.User.GetLogin(), issue.IsPullRequest(), issue.GetBody())
}

var validGithubRepoName = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)

func (p *Plugin) toolGetGithubIssue(context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args GetGithubIssueArgs
err := argsGetter(&args)
if err != nil {
return "invalid parameters to function", errors.Wrap(err, "failed to get arguments for tool GetGithubIssues")
}

// Fail for over lengh repo ownder or name.
if len(args.RepoOwner) > 39 || len(args.RepoName) > 100 {
return "invalid parameters to function", errors.New("invalid repo owner or repo name")
}

// Fail if repo ownder or repo name contain invalid characters.
if !validGithubRepoName.MatchString(args.RepoOwner) || !validGithubRepoName.MatchString(args.RepoName) {
return "invalid parameters to function", errors.New("invalid repo owner or repo name")
}

// Fail for bad issue numbers.
if args.Number < 1 {
return "invalid parameters to function", errors.New("invalid issue number")
}

req, err := http.NewRequest("GET", fmt.Sprintf("/github/api/v1/issue?owner=%s&repo=%s&number=%d", args.RepoOwner, args.RepoName, args.Number), nil)
if err != nil {
return "internal failure", errors.Wrap(err, "failed to create request")
Expand Down

0 comments on commit b626aaf

Please sign in to comment.