-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathstream_reader.go
128 lines (110 loc) · 2.69 KB
/
stream_reader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package go_ernie
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
utils "github.com/anhao/go-ernie/internal"
)
// Line prefixes recognized while reading a streaming response body.
var (
	// headerData introduces an event payload line ("data: {...}").
	headerData = []byte(`data: `)
	// errorPrefix is how an API error object begins; a line starting with it
	// is treated as an error body rather than an event.
	errorPrefix = []byte("{\"error_code\":")
)
type streamable interface {
ErnieBotResponse |
ErnieBotTurboResponse |
Bloomz7b1Response |
LlamaChatResponse |
BaiduChatResponse |
ErnieCustomPluginResponse |
ErnieBot4Response |
CompletionResponse |
ErnieBot8KResponse |
ErnieBotTurboAIResponse
}
// streamReader reads a line-oriented streaming HTTP response and decodes it
// into values of type T, one per successful Recv call.
type streamReader[T streamable] struct {
	// emptyMessagesLimit caps how many non-"data: " lines a single Recv call
	// may consume before it fails with ErrTooManyEmptyStreamMessages.
	emptyMessagesLimit uint
	// isFinished is set when a data line with an empty payload is seen;
	// afterwards Recv returns io.EOF without reading.
	isFinished bool
	// reader supplies the stream's lines (presumably wrapping response.Body —
	// NOTE(review): confirm at the construction site).
	reader *bufio.Reader
	// response is the HTTP response being streamed; Close closes its Body.
	response *http.Response
	// errAccumulator collects non-data lines so they can later be decoded as
	// an APIError.
	errAccumulator utils.ErrorAccumulator
	// unmarshaler decodes data payloads into T and accumulated error bytes
	// into APIError.
	unmarshaler utils.Unmarshaler
}
// Recv returns the next decoded event from the stream.
//
// While the stream is still open it delegates to processLines. Once the
// reader has been marked finished, it returns the zero value of T together
// with io.EOF without touching the underlying reader.
func (stream *streamReader[T]) Recv() (response T, err error) {
	if !stream.isFinished {
		return stream.processLines()
	}
	return response, io.EOF
}
// processLines reads lines from the stream until it can return one decoded
// event of type T, an error, or io.EOF.
//
// Each line is whitespace-trimmed and then handled as follows:
//   - a line starting with "data: " has that prefix removed and the rest is
//     decoded into T with stream.unmarshaler;
//   - any other line is written to stream.errAccumulator and counted; more
//     than stream.emptyMessagesLimit such lines in one call yields
//     ErrTooManyEmptyStreamMessages;
//   - a line starting with errorPrefix (`{"error_code":`) is accumulated the
//     same way and makes the next iteration report an error.
//
// When a read fails (including io.EOF) or an error line has been flagged:
//   - if the accumulated bytes decode as an *APIError, it is returned wrapped
//     as "error, <msg>" (recoverable with errors.As);
//   - otherwise, if the line just read is itself an error object, it is
//     decoded and returned as *APIError;
//   - otherwise the read error is returned;
//   - if there was no read error (a flagged error payload could not be
//     decoded), an error quoting that payload is returned, so callers never
//     receive a zero-value T together with a nil error.
func (stream *streamReader[T]) processLines() (T, error) {
	var (
		emptyMessagesCount uint
		hasErrorPrefix     bool
		errorLine          []byte
		apiError           APIError
	)
	for {
		rawLine, readErr := stream.reader.ReadBytes('\n')
		if readErr != nil || hasErrorPrefix {
			if respErr := stream.unmarshalError(); respErr != nil {
				// %w keeps *APIError reachable through errors.As while
				// producing the same message text as before.
				return *new(T), fmt.Errorf("error, %w", respErr)
			}
			// ReadBytes returns the data read before an error together with
			// that error, so an error object on an unterminated final line
			// arrives here in rawLine.
			noSpaceLine := bytes.TrimSpace(rawLine)
			if bytes.HasPrefix(noSpaceLine, errorPrefix) {
				unmarshaler := utils.JSONUnmarshaler{}
				if err := unmarshaler.Unmarshal(noSpaceLine, &apiError); err != nil {
					return *new(T), err
				}
				return *new(T), &apiError
			}
			if readErr == nil {
				// We got here only because an error line was flagged, but it
				// could not be decoded: surface it instead of (zero, nil).
				return *new(T), fmt.Errorf("error, undecodable error payload: %s", errorLine)
			}
			return *new(T), readErr
		}

		noSpaceLine := bytes.TrimSpace(rawLine)
		if bytes.HasPrefix(noSpaceLine, errorPrefix) {
			hasErrorPrefix = true
			// ReadBytes returns a freshly allocated slice, so keeping a
			// reference for the error message above is safe.
			errorLine = noSpaceLine
		}
		if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
			if hasErrorPrefix {
				noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
			}
			if writeErr := stream.errAccumulator.Write(noSpaceLine); writeErr != nil {
				return *new(T), writeErr
			}
			emptyMessagesCount++
			if emptyMessagesCount > stream.emptyMessagesLimit {
				return *new(T), ErrTooManyEmptyStreamMessages
			}
			continue
		}

		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
		// NOTE(review): noSpaceLine has already had trailing whitespace
		// trimmed, so a bare "data: " line becomes "data:" and takes the
		// non-data branch above; this end-of-stream check therefore looks
		// unreachable. Confirm how the service signals completion before
		// changing it.
		if len(noPrefixLine) == 0 {
			stream.isFinished = true
			return *new(T), io.EOF
		}

		var response T
		if unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response); unmarshalErr != nil {
			return *new(T), unmarshalErr
		}
		return response, nil
	}
}
// unmarshalError decodes whatever has been gathered in errAccumulator as an
// APIError. It returns nil when nothing has been accumulated or when the
// accumulated bytes cannot be decoded.
func (stream *streamReader[T]) unmarshalError() (errResp *APIError) {
	payload := stream.errAccumulator.Bytes()
	if len(payload) == 0 {
		return nil
	}
	var decoded *APIError
	if stream.unmarshaler.Unmarshal(payload, &decoded) != nil {
		return nil
	}
	return decoded
}
// Close closes the body of the underlying HTTP response. Any error reported
// by Body.Close is discarded, since Close has no error result.
func (stream *streamReader[T]) Close() {
	body := stream.response.Body
	_ = body.Close()
}