Improve ai plugins (alibaba#1657)
Co-authored-by: Kent Dong <[email protected]>
rinfx and CH3CHO authored Jan 9, 2025
1 parent 2a89c3b commit ea0d5e7
Showing 5 changed files with 31 additions and 34 deletions.
plugins/wasm-go/extensions/ai-cache/config/config.go (4 changes: 2 additions & 2 deletions)
@@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {

c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" {
- c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
+ c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
}
c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" {
- c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
+ c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}

if json.Get("enableSemanticCache").Exists() {
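Note on the change above: the cache plugin's fallback templates now report the fixed marker "from-cache" as the model instead of a hardcoded "gpt-4o". A minimal sketch (not code from the plugin; the cached answer below is made up) of how such a template is rendered with fmt.Sprintf when a cache hit is served:

package main

import "fmt"

func main() {
	// Default non-streaming template from the diff above; the single %s receives the cached answer.
	responseTemplate := `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	cachedAnswer := "Paris is the capital of France." // hypothetical value; the plugin reads this from its cache
	fmt.Println(fmt.Sprintf(responseTemplate, cachedAnswer))
}
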
plugins/wasm-go/extensions/ai-prompt-template/main.go (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.

func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
- if templateEnable != "true" {
+ if templateEnable == "false" {
ctx.DontReadRequestBody()
return types.ActionContinue
}
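Note on the change above: templating is now skipped only when the template-enable header is explicitly "false"; previously any value other than "true" (including a missing header) disabled it. A small standalone sketch of the new decision, using a hypothetical helper name:

package main

import "fmt"

// shouldSkipTemplating mirrors the new guard; it is a hypothetical helper, not plugin code.
func shouldSkipTemplating(templateEnable string) bool {
	return templateEnable == "false"
}

func main() {
	fmt.Println(shouldSkipTemplating(""))      // false: header absent, templating now runs
	fmt.Println(shouldSkipTemplating("true"))  // false: templating runs, as before
	fmt.Println(shouldSkipTemplating("false")) // true: only an explicit "false" opts out
}
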
plugins/wasm-go/extensions/ai-proxy/provider/failover.go (9 changes: 5 additions & 4 deletions)
@@ -4,14 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"math/rand"
"net/http"
"strings"
"time"


"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
@@ -551,7 +551,8 @@ func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.Ht
}

func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
- return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
+ token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
+ return token
}

func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
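Note on the change above: GetApiTokenInUse now uses the two-value ("comma, ok") form of the type assertion, so a missing or non-string context entry yields an empty token instead of a runtime panic. A standalone sketch of the pattern (the stored variable stands in for whatever ctx.GetContext returns; the token value is made up):

package main

import "fmt"

func main() {
	var stored interface{} // stand-in for ctx.GetContext(...); nil when nothing was stored

	// Two-value assertion: ok reports whether the value is a string; token stays ""
	// on failure. The old single-value form, stored.(string), would panic here.
	token, ok := stored.(string)
	fmt.Printf("token=%q ok=%v\n", token, ok) // token="" ok=false

	stored = "sk-example-token" // hypothetical token value
	token, ok = stored.(string)
	fmt.Printf("token=%q ok=%v\n", token, ok) // token="sk-example-token" ok=true
}
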
plugins/wasm-go/extensions/ai-security-guard/main.go (26 changes: 7 additions & 19 deletions)
@@ -41,9 +41,9 @@ const (
LowRisk = "low"
NoRisk = "none"

- OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
- OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
- OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
+ OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
+ OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
+ OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`

DefaultRequestCheckService = "llm_query_moderation"
@@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
log.Debugf("checking request body...")
startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
- model := gjson.GetBytes(body, "model").String()
- ctx.SetContext("requestModel", model)
log.Debugf("Raw request content is: %s", content)
if len(content) == 0 {
log.Info("request content is empty. skip")
@@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
- jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
+ jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
- jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
+ jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
@@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
return types.ActionPause
}

- func convertHeaders(hs [][2]string) map[string][]string {
- ret := make(map[string][]string)
- for _, h := range hs {
- k, v := strings.ToLower(h[0]), h[1]
- ret[k] = append(ret[k], v)
- }
- return ret
- }

func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse {
log.Debugf("response checking is disabled")
@@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(contentType, "event-stream")
- model := ctx.GetStringContext("requestModel", "unknown")
var content string
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
@@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse {
randomID := generateRandomID()
- jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
+ jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
- jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
+ jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
config.incrementCounter("ai_sec_response_deny", 1)
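Note on the changes above: the deny-response templates no longer take the request model (the model field is fixed to "from-security-guard"), each fmt.Sprintf call drops the corresponding arguments, and a usage block with zeroed token counts is added. A sketch (the ID and deny message below are made up) showing how the combined streaming format consumes its three remaining %s verbs, in order: chunk id, chunk content, end id:

package main

import "fmt"

const (
	chunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
	end   = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	full  = chunk + "\n\n" + end + "\n\n" + `data: [DONE]`
)

func main() {
	randomID := "chatcmpl-123"                     // illustrative; the plugin generates a random ID
	deny := "I am sorry, I cannot help with that." // illustrative deny message
	fmt.Println(fmt.Sprintf(full, randomID, deny, randomID))
}
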
plugins/wasm-go/extensions/ai-statistics/main.go (24 changes: 16 additions & 8 deletions)
@@ -36,6 +36,7 @@ const (
RouteName = "route"
ClusterName = "cluster"
APIName = "api"
+ ConsumerKey = "x-mse-consumer"

// Source Type
FixedValue = "fixed_value"
@@ -81,8 +82,8 @@ type AIStatisticsConfig struct {
shouldBufferStreamingBody bool
}

- func generateMetricName(route, cluster, model, metricName string) string {
- return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName)
+ func generateMetricName(route, cluster, model, consumer, metricName string) string {
+ return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
}

func getRouteName() (string, error) {
@@ -115,6 +116,9 @@ func getClusterName() (string, error) {
}

func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
+ if inc == 0 {
+ return
+ }
counter, ok := config.counterMetrics[metricName]
if !ok {
counter = proxywasm.DefineCounterMetric(metricName)
@@ -158,6 +162,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
ctx.SetContext(ClusterName, cluster)
ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
+ if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
+ ctx.SetContext(ConsumerKey, consumer)
+ }

// Set user defined log & span attributes which type is fixed_value
setAttributeBySource(ctx, config, FixedValue, nil, log)
@@ -388,6 +395,7 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
var ok bool
var route, cluster, model string
var inputToken, outputToken uint64
+ consumer := ctx.GetStringContext(ConsumerKey, "none")
route, ok = ctx.GetContext(RouteName).(string)
if !ok {
log.Warnf("RouteName typd assert failed, skip metric record")
@@ -421,8 +429,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return
}
- config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
- config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)

// Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration uint64
@@ -433,17 +441,17 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
log.Warnf("LLMFirstTokenDuration typd assert failed")
return
}
- config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
- config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
}
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
if !ok {
log.Warnf("LLMServiceDuration typd assert failed")
return
}
- config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
- config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
+ config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
}
}

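Note on the changes above: token and latency counters now carry a consumer dimension, taken from the x-mse-consumer request header and defaulting to "none", and zero-valued increments are skipped. A sketch of the resulting metric names with made-up route, cluster, model, and metric-suffix values:

package main

import "fmt"

// Same shape as the updated generateMetricName in the diff above.
func generateMetricName(route, cluster, model, consumer, metricName string) string {
	return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
}

func main() {
	// Illustrative label values only; in the plugin they come from the request context.
	fmt.Println(generateMetricName("llm-route", "outbound|443||qwen.dns", "qwen-turbo", "none", "input_token"))
	// Output: route.llm-route.upstream.outbound|443||qwen.dns.model.qwen-turbo.consumer.none.metric.input_token
}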
