From 2a8dbcddec780a72bc702c6306c191775aed085a Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Sun, 8 Oct 2023 17:38:31 +0800 Subject: [PATCH] feat: add dashscope as llm Signed-off-by: Abirdcfly --- README.md | 16 +- controllers/prompt_controller.go | 2 +- examples/dashscope/main.go | 79 ++++++++ pkg/llms/dashscope/api.go | 66 +++++++ pkg/llms/dashscope/http_client.go | 63 ++++++ pkg/llms/dashscope/params.go | 101 ++++++++++ pkg/llms/dashscope/response.go | 77 ++++++++ pkg/llms/dashscope/sse_client.go | 200 ++++++++++++++++++++ pkg/llms/dashscope/zz_generated.deepcopy.go | 76 ++++++++ pkg/llms/llms.go | 13 +- pkg/llms/openai/response.go | 2 +- pkg/llms/zhipuai/api.go | 2 +- pkg/llms/zhipuai/params.go | 4 +- pkg/llms/zhipuai/response.go | 2 +- 14 files changed, 683 insertions(+), 20 deletions(-) create mode 100644 examples/dashscope/main.go create mode 100644 pkg/llms/dashscope/api.go create mode 100644 pkg/llms/dashscope/http_client.go create mode 100644 pkg/llms/dashscope/params.go create mode 100644 pkg/llms/dashscope/response.go create mode 100644 pkg/llms/dashscope/sse_client.go create mode 100644 pkg/llms/dashscope/zz_generated.deepcopy.go diff --git a/README.md b/README.md index b3fa823d1..08ae5522a 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,13 @@ kubectl apply -f config/samples/arcadia_v1alpha1_llm.yaml kubectl apply -f config/samples/arcadia_v1alpha1_prompt.yaml ``` -After prompt got created, you can see the prompt in the following command: +After the prompt got created, you can see the prompt in the following command: ```shell kubectl get prompt prompt-zhipuai-sample -oyaml ``` -If no error found,you can use this command to get the prompt response data. +If no error is found, you can use this command to get the prompt response data. 
```shell kubectl get prompt prompt-zhipuai-sample --output="jsonpath={.status.data}" | base64 --decode @@ -56,18 +56,18 @@ go install github.com/kubeagi/arcadia/arctl@latest ## Packages -To enhace the AI capability in Golang,we developed some packages. +To enhance the AI capability in Golang, we developed some packages. ### LLMs -- ✅ [ZhiPuAI(智谱AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai) +- ✅ [ZhiPuAI(智谱 AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai) - [example](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) ### Embeddings > Fully compatible with [langchain embeddings](https://github.com/tmc/langchaingo/tree/main/embeddings) -- ✅[ZhiPuAI(智谱AI) Embedding](https://github.com/kubeagi/arcadia/tree/main/pkg/embeddings/zhipuai) +- ✅[ZhiPuAI(智谱 AI) Embedding](https://github.com/kubeagi/arcadia/tree/main/pkg/embeddings/zhipuai) ### VectorStores @@ -77,10 +77,10 @@ To enhace the AI capability in Golang,we developed some packages. ## Examples -- [chat_with_document](https://github.com/kubeagi/arcadia/tree/main/examples/chat_with_document): a chat server which allows you chat with your document +- [chat_with_document](https://github.com/kubeagi/arcadia/tree/main/examples/chat_with_document): a chat server which allows you to chat with your document - [embedding](https://github.com/kubeagi/arcadia/tree/main/examples/embedding) shows how to embedes your document to vector store with embedding service -- [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows to to inquiry the security risks in your RBAC with AI. -- [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) show how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai) +- [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows how to inquire about the security risks in your RBAC with AI.
+- [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) shows how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai) ## Contribute to Arcadia diff --git a/controllers/prompt_controller.go b/controllers/prompt_controller.go index a1aef83e2..15f8df808 100644 --- a/controllers/prompt_controller.go +++ b/controllers/prompt_controller.go @@ -100,7 +100,7 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom switch llm.Spec.Type { case llms.ZhiPuAI: llmClient = llmszhipuai.NewZhiPuAI(apiKey) - callData = prompt.Spec.ZhiPuAIParams.Marshall() + callData = prompt.Spec.ZhiPuAIParams.Marshal() case llms.OpenAI: llmClient = openai.NewOpenAI(apiKey) default: diff --git a/examples/dashscope/main.go b/examples/dashscope/main.go new file mode 100644 index 000000000..c60baa9d2 --- /dev/null +++ b/examples/dashscope/main.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "os" + + "github.com/kubeagi/arcadia/pkg/llms" + "github.com/kubeagi/arcadia/pkg/llms/dashscope" + "k8s.io/klog/v2" +) + +const ( + samplePrompt = "how to change a deployment's image?" 
+) + +func main() { + if len(os.Args) == 1 { + panic("api key is empty") + } + apiKey := os.Args[1] + klog.Infof("sample chat start...\nwe use same prompt: %s to test\n", samplePrompt) + for _, model := range []dashscope.Model{dashscope.QWEN14BChat, dashscope.QWEN7BChat} { + klog.V(0).Infof("\nChat with %s\n", model) + resp, err := sampleChat(apiKey, model) + if err != nil { + panic(err) + } + klog.V(0).Infof("Response: \n %s\n", resp) + klog.V(0).Infoln("\nChat again with sse enable") + err = sampleSSEChat(apiKey, model) + if err != nil { + panic(err) + } + } + klog.Infoln("sample chat done") +} + +func sampleChat(apiKey string, model dashscope.Model) (llms.Response, error) { + client := dashscope.NewDashScope(apiKey, false) + params := dashscope.DefaultModelParams() + params.Model = model + params.Input.Messages = []dashscope.Message{ + {Role: dashscope.System, Content: "You are a kubernetes expert."}, + {Role: dashscope.User, Content: samplePrompt}, + } + return client.Call(params.Marshal()) +} + +func sampleSSEChat(apiKey string, model dashscope.Model) error { + client := dashscope.NewDashScope(apiKey, true) + params := dashscope.DefaultModelParams() + params.Model = model + params.Input.Messages = []dashscope.Message{ + {Role: dashscope.System, Content: "You are a kubernetes expert."}, + {Role: dashscope.User, Content: samplePrompt}, + } + // you can define a customized `handler` on `Event` + err := client.StreamCall(context.TODO(), params.Marshal(), nil) + if err != nil { + return err + } + return nil +} diff --git a/pkg/llms/dashscope/api.go b/pkg/llms/dashscope/api.go new file mode 100644 index 000000000..2d1153370 --- /dev/null +++ b/pkg/llms/dashscope/api.go @@ -0,0 +1,66 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dashscope + +import ( + "context" + "errors" + + "github.com/kubeagi/arcadia/pkg/llms" +) + +const ( + DashScopeChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" +) + +type Model string + +const ( + QWEN14BChat Model = "qwen-14b-chat" + QWEN7BChat Model = "qwen-7b-chat" +) + +var _ llms.LLM = (*DashScope)(nil) + +type DashScope struct { + apiKey string + sse bool +} + +func NewDashScope(apiKey string, sse bool) *DashScope { + return &DashScope{ + apiKey: apiKey, + sse: sse, + } +} + +func (z DashScope) Type() llms.LLMType { + return llms.DashScope +} + +// Call wraps a common AI api call +func (z *DashScope) Call(data []byte) (llms.Response, error) { + params := ModelParams{} + if err := params.Unmarshal(data); err != nil { + return nil, err + } + return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse) +} + +func (z *DashScope) Validate() (llms.Response, error) { + return nil, errors.New("not implemented") +} diff --git a/pkg/llms/dashscope/http_client.go b/pkg/llms/dashscope/http_client.go new file mode 100644 index 000000000..8485b731a --- /dev/null +++ b/pkg/llms/dashscope/http_client.go @@ -0,0 +1,63 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dashscope + +import ( + "bytes" + "context" + "encoding/json" + "net/http" +) + +func setHeaders(req *http.Request, token string, sse bool) { + if sse { + // req.Header.Set("Content-Type", "text/event-stream") // Although the documentation says we should do this, but will return a 400 error and the python sdk doesn't do this. + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-DashScope-SSE", "enable") + } else { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + } + req.Header.Set("Authorization", "Bearer "+token) +} + +func parseHTTPResponse(resp *http.Response) (data *Response, err error) { + if err = json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, err + } + return data, nil +} + +func req(ctx context.Context, apiURL, token string, data []byte, sse bool) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(data)) + if err != nil { + return nil, err + } + + setHeaders(req, token, sse) + + return http.DefaultClient.Do(req) +} +func do(ctx context.Context, apiURL, token string, data []byte, sse bool) (*Response, error) { + resp, err := req(ctx, apiURL, token, data, sse) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return parseHTTPResponse(resp) +} diff --git a/pkg/llms/dashscope/params.go b/pkg/llms/dashscope/params.go new file mode 100644 index 000000000..92390ea3a --- /dev/null +++ b/pkg/llms/dashscope/params.go @@ -0,0 +1,101 @@ +/* +Copyright 2023 
KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dashscope + +import ( + "encoding/json" + "errors" + + "github.com/kubeagi/arcadia/pkg/llms" +) + +type Role string + +const ( + System Role = "system" + User Role = "user" + Assistant Role = "assistant" +) + +var _ llms.ModelParams = (*ModelParams)(nil) + +// +kubebuilder:object:generate=true + +// ModelParams +// ref: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-api-detailes#25745d61fbx49 +// do not use 'input.history', according to the above document, this parameter will be deprecated soon. +// use 'message' in 'parameters.result_format' to keep better compatibility. 
+type ModelParams struct { + Model Model `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters"` +} + +// +kubebuilder:object:generate=true + +type Input struct { + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP float32 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed int `json:"seed,omitempty"` + ResultFormat string `json:"result_format,omitempty"` +} + +// +kubebuilder:object:generate=true + +type Message struct { + Role Role `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +func DefaultModelParams() ModelParams { + return ModelParams{ + Model: QWEN14BChat, + Input: Input{ + Messages: []Message{}, + }, + Parameters: Parameters{ + TopP: 0.5, + TopK: 0, + Seed: 1234, + ResultFormat: "message", + }, + } +} + +func (params *ModelParams) Marshal() []byte { + data, err := json.Marshal(params) + if err != nil { + return []byte{} + } + return data +} + +func (params *ModelParams) Unmarshal(bytes []byte) error { + return json.Unmarshal(bytes, params) +} + +func ValidateModelParams(params ModelParams) error { + if params.Parameters.TopP < 0 || params.Parameters.TopP > 1 { + return errors.New("top_p must be in (0, 1)") + } + + return nil +} diff --git a/pkg/llms/dashscope/response.go b/pkg/llms/dashscope/response.go new file mode 100644 index 000000000..dccc65b79 --- /dev/null +++ b/pkg/llms/dashscope/response.go @@ -0,0 +1,77 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dashscope + +import ( + "encoding/json" + + "github.com/kubeagi/arcadia/pkg/llms" +) + +var _ llms.Response = (*Response)(nil) + +type Response struct { + // https://help.aliyun.com/zh/dashscope/response-status-codes + StatusCode int `json:"status_code,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output Output `json:"output"` + Usage Usage `json:"usage"` + RequestID string `json:"request_id"` +} + +type Output struct { + Choices []Choice `json:"choices"` +} + +type FinishReason string + +const ( + Finish FinishReason = "stop" + Generating FinishReason = "null" + ToLoogin FinishReason = "length" +) + +type Choice struct { + FinishReason FinishReason `json:"finish_reason"` + Message Message `json:"message"` +} + +type Usage struct { + OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` +} + +func (response *Response) Unmarshal(bytes []byte) error { + return json.Unmarshal(bytes, response) +} + +func (response *Response) Type() llms.LLMType { + return llms.DashScope +} + +func (response *Response) Bytes() []byte { + bytes, err := json.Marshal(response) + if err != nil { + return []byte{} + } + return bytes +} + +func (response *Response) String() string { + return string(response.Bytes()) +} diff --git a/pkg/llms/dashscope/sse_client.go b/pkg/llms/dashscope/sse_client.go new file mode 100644 index 000000000..881272df1 --- /dev/null +++ b/pkg/llms/dashscope/sse_client.go @@ -0,0 +1,200 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// NOTE: Reference https://github.com/r3labs/sse/client.go + +package dashscope + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" + + "github.com/r3labs/sse/v2" +) + +func defaultHandler(event *sse.Event, last string) (newData string) { + switch string(event.Event) { + case "add", "error", "interrupted": + fmt.Printf("%s", event.Data) + case "result": + data := new(Response) + _ = json.Unmarshal(event.Data, data) + newData = data.Output.Choices[len(data.Output.Choices)-1].Message.Content + fmt.Printf("%s", strings.TrimPrefix(newData, last)) + return newData + default: + fmt.Printf("%s", event.Data) + } + return "" +} +func (z *DashScope) StreamCall(ctx context.Context, data []byte, handler func(event *sse.Event, last string) (data string)) error { + resp, err := req(ctx, DashScopeChatURL, z.apiKey, data, true) + if err != nil { + return err + } + defer resp.Body.Close() + // parse response body as stream events + eventChan, errorChan := NewSSEClient().Events(resp) + + // handle events + if handler == nil { + handler = defaultHandler + } + + last := "" + for { + select { + case err = <-errorChan: + return err + case msg := <-eventChan: + last = handler(msg, last) + } + } +} + +var ( + headerID = []byte("id:") + headerData = []byte("data:") + headerEvent = []byte("event:") + headerRetry = []byte("retry:") +) + +type SSEClient struct { + LastEventID atomic.Value // []byte + EncodingBase64 bool + + maxBufferSize int +} + +func NewSSEClient() *SSEClient { + return &SSEClient{ + 
maxBufferSize: 1 << 16, + } +} + +func (c *SSEClient) Events(resp *http.Response) (<-chan *sse.Event, <-chan error) { + reader := sse.NewEventStreamReader(resp.Body, c.maxBufferSize) + return c.startReadLoop(reader) +} + +func (c *SSEClient) startReadLoop(reader *sse.EventStreamReader) (chan *sse.Event, chan error) { + outCh := make(chan *sse.Event) + erChan := make(chan error) + go c.readLoop(reader, outCh, erChan) + return outCh, erChan +} + +func (c *SSEClient) readLoop(reader *sse.EventStreamReader, outCh chan *sse.Event, erChan chan error) { + for { + // Read each new line and process the type of event + event, err := reader.ReadEvent() + if err != nil { + if err == io.EOF { + erChan <- nil + return + } + erChan <- err + return + } + + // If we get an error, ignore it. + var msg *sse.Event + if msg, err = c.processEvent(event); err == nil { + if len(msg.ID) > 0 { + c.LastEventID.Store(msg.ID) + } else { + msg.ID, _ = c.LastEventID.Load().([]byte) + } + + // Send downstream if the event has something useful + if hasContent(msg) { + outCh <- msg + } + } + } +} + +func (c *SSEClient) processEvent(msg []byte) (event *sse.Event, err error) { + var e sse.Event + + if len(msg) < 1 { + return nil, errors.New("event message was empty") + } + + // Normalize the crlf to lf to make it easier to split the lines. + // Split the line by "\n" or "\r", per the spec. + for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) { + switch { + case bytes.HasPrefix(line, headerID): + e.ID = append([]byte(nil), trimHeader(len(headerID), line)...) + case bytes.HasPrefix(line, headerData): + // The spec allows for multiple data fields per event, concatenated them with "\n". + e.Data = append(e.Data, append(trimHeader(len(headerData), line), byte('\n'))...) + // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. 
+ case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): + e.Data = append(e.Data, byte('\n')) + case bytes.HasPrefix(line, headerEvent): + e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) + case bytes.HasPrefix(line, headerRetry): + e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) + default: + // Ignore any garbage that doesn't match what we're looking for. + } + } + + // Trim the last "\n" per the spec. + e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) + + if c.EncodingBase64 { + buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data))) + + n, err := base64.StdEncoding.Decode(buf, e.Data) + if err != nil { + return nil, fmt.Errorf("failed to decode event message: %s", err) + } + e.Data = buf[:n] + } + return &e, err +} + +func hasContent(e *sse.Event) bool { + return len(e.ID) > 0 || len(e.Data) > 0 || len(e.Event) > 0 || len(e.Retry) > 0 +} + +func trimHeader(size int, data []byte) []byte { + if data == nil || len(data) < size { + return data + } + + data = data[size:] + // Remove optional leading whitespace + if len(data) > 0 && data[0] == 32 { + data = data[1:] + } + // Remove trailing new line + if len(data) > 0 && data[len(data)-1] == 10 { + data = data[:len(data)-1] + } + return data +} diff --git a/pkg/llms/dashscope/zz_generated.deepcopy.go b/pkg/llms/dashscope/zz_generated.deepcopy.go new file mode 100644 index 000000000..ee48807bd --- /dev/null +++ b/pkg/llms/dashscope/zz_generated.deepcopy.go @@ -0,0 +1,76 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by controller-gen. DO NOT EDIT. + +package dashscope + +import () + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Input) DeepCopyInto(out *Input) { + *out = *in + if in.Messages != nil { + in, out := &in.Messages, &out.Messages + *out = make([]Message, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Input. +func (in *Input) DeepCopy() *Input { + if in == nil { + return nil + } + out := new(Input) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Message) DeepCopyInto(out *Message) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Message. +func (in *Message) DeepCopy() *Message { + if in == nil { + return nil + } + out := new(Message) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ModelParams) DeepCopyInto(out *ModelParams) { + *out = *in + in.Input.DeepCopyInto(&out.Input) + out.Parameters = in.Parameters +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelParams. 
+func (in *ModelParams) DeepCopy() *ModelParams { + if in == nil { + return nil + } + out := new(ModelParams) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/llms/llms.go b/pkg/llms/llms.go index 9eee1b1b9..e9a1d6a98 100644 --- a/pkg/llms/llms.go +++ b/pkg/llms/llms.go @@ -21,9 +21,10 @@ import "errors" type LLMType string const ( - OpenAI LLMType = "openai" - ZhiPuAI LLMType = "zhipuai" - Unknown LLMType = "unknown" + OpenAI LLMType = "openai" + ZhiPuAI LLMType = "zhipuai" + DashScope LLMType = "dashscope" + Unknown LLMType = "unknown" ) type LLM interface { @@ -33,15 +34,15 @@ type LLM interface { } type ModelParams interface { - Marshall() []byte - Unmarshall([]byte) error + Marshal() []byte + Unmarshal([]byte) error } type Response interface { Type() LLMType String() string Bytes() []byte - Unmarshall([]byte) error + Unmarshal([]byte) error } type UnknowLLM struct{} diff --git a/pkg/llms/openai/response.go b/pkg/llms/openai/response.go index d7080d827..c840ea9f3 100644 --- a/pkg/llms/openai/response.go +++ b/pkg/llms/openai/response.go @@ -47,7 +47,7 @@ func (response *Response) String() string { return string(response.Bytes()) } -func (response *Response) Unmarshall(bytes []byte) error { +func (response *Response) Unmarshal(bytes []byte) error { return json.Unmarshal(bytes, response) } diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go index 4026b5d65..c321dca21 100644 --- a/pkg/llms/zhipuai/api.go +++ b/pkg/llms/zhipuai/api.go @@ -77,7 +77,7 @@ func (z ZhiPuAI) Type() llms.LLMType { // Call wraps a common AI api call func (z *ZhiPuAI) Call(data []byte) (llms.Response, error) { params := ModelParams{} - if err := params.Unmarshall(data); err != nil { + if err := params.Unmarshal(data); err != nil { return nil, err } switch params.Method { diff --git a/pkg/llms/zhipuai/params.go b/pkg/llms/zhipuai/params.go index 3ee3c37b4..a108059e4 100644 --- a/pkg/llms/zhipuai/params.go +++ b/pkg/llms/zhipuai/params.go @@ -74,7 +74,7 @@ func 
DefaultModelParams() ModelParams { } } -func (params *ModelParams) Marshall() []byte { +func (params *ModelParams) Marshal() []byte { data, err := json.Marshal(params) if err != nil { return []byte{} @@ -82,7 +82,7 @@ func (params *ModelParams) Marshall() []byte { return data } -func (params *ModelParams) Unmarshall(bytes []byte) error { +func (params *ModelParams) Unmarshal(bytes []byte) error { return json.Unmarshal(bytes, params) } diff --git a/pkg/llms/zhipuai/response.go b/pkg/llms/zhipuai/response.go index 9be365707..4f0a811de 100644 --- a/pkg/llms/zhipuai/response.go +++ b/pkg/llms/zhipuai/response.go @@ -56,7 +56,7 @@ type Response struct { Success bool `json:"success"` } -func (response *Response) Unmarshall(bytes []byte) error { +func (response *Response) Unmarshal(bytes []byte) error { return json.Unmarshal(response.Bytes(), response) }