diff --git a/server/api_post.go b/server/api_post.go index 5e9b2518..8be8c648 100644 --- a/server/api_post.go +++ b/server/api_post.go @@ -126,6 +126,16 @@ func (p *Plugin) handleRegenerate(c *gin.Context) { post := c.MustGet(ContextPostKey).(*model.Post) channel := c.MustGet(ContextChannelKey).(*model.Channel) + if post.UserId != p.botid { + c.AbortWithError(http.StatusBadRequest, errors.New("Not a AI bot post")) + return + } + + if post.GetProp("llm_requester_user_id") != userID { + c.AbortWithError(http.StatusForbidden, errors.New("only the original poster can regenerate")) + return + } + user, err := p.pluginAPI.User.Get(userID) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) @@ -141,11 +151,6 @@ func (p *Plugin) handleRegenerate(c *gin.Context) { postToRegenerate := threadData.latestPost() - if user.Id != postToRegenerate.UserId { - c.AbortWithError(http.StatusForbidden, errors.New("only the original poster can regenerate")) - return - } - context := p.MakeConversationContext(user, channel, postToRegenerate) conversation, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, context) if err != nil { diff --git a/server/service.go b/server/service.go index 7e6704cb..13eb05f9 100644 --- a/server/service.go +++ b/server/service.go @@ -81,7 +81,7 @@ func (p *Plugin) continueConversation(context ai.ConversationContext) error { // Special handing for threads started by the bot in response to a summarization request. var result *ai.TextStreamResult originalThreadID, ok := threadData.Posts[0].GetProp(ThreadIDProp).(string) - if ok && originalThreadID != "" { + if ok && originalThreadID != "" && threadData.Posts[0].UserId == p.botid { threadPost, err := p.pluginAPI.Post.GetPost(originalThreadID) if err != nil { return err