Skip to content

Commit

Permalink
Add StreamingStdOutHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 27, 2023
1 parent 859a262 commit c4c7e94
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 4 deletions.
33 changes: 33 additions & 0 deletions callback/streaming_stdout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package callback

import (
"context"
"fmt"
"io"
"os"

"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure StreamingStdOutHandler satisfies the Callback interface.
var _ schema.Callback = (*StreamingStdOutHandler)(nil)

type StreamingStdOutHandler struct {
handler
writer io.Writer
}

func NewStreamingStdOutHandler() *StreamingStdOutHandler {
return &StreamingStdOutHandler{
writer: os.Stdout,
}
}

func (cb *StreamingStdOutHandler) AlwaysVerbose() bool {
return true
}

func (cb *StreamingStdOutHandler) OnModelNewToken(ctx context.Context, input *schema.ModelNewTokenInput) error {
fmt.Fprint(cb.writer, input.Token)
return nil
}
38 changes: 38 additions & 0 deletions docs/content/en/docs/llms_and_prompts/models/llms/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,42 @@ openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
// Error handling
}
```

## Streaming
```go
package main

import (
"context"
"log"
"os"

"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/model"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

func main() {
openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"), func(o *llm.OpenAIOptions) {
o.Stream = true
})
if err != nil {
log.Fatal(err)
}

_, mErr := model.GeneratePrompt(context.Background(), openai, prompt.StringPromptValue("Write me a song about sparkling water."), func(o *model.Options) {
o.Callbacks = []schema.Callback{callback.NewStreamingStdOutHandler()}
})
if mErr != nil {
log.Fatal(mErr)
}
}
```
Output:
```text
Verse 1:
There's a little sparkle...
```
29 changes: 29 additions & 0 deletions examples/llm_streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package main

import (
"context"
"log"
"os"

"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/model"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

func main() {
openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"), func(o *llm.OpenAIOptions) {
o.Stream = true
})
if err != nil {
log.Fatal(err)
}

_, mErr := model.GeneratePrompt(context.Background(), openai, prompt.StringPromptValue("Write me a song about sparkling water."), func(o *model.Options) {
o.Callbacks = []schema.Callback{callback.NewStreamingStdOutHandler()}
})
if mErr != nil {
log.Fatal(mErr)
}
}
9 changes: 8 additions & 1 deletion model/llm/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"io"
"strings"

"github.com/avast/retry-go"
"github.com/hupe1980/golc"
Expand Down Expand Up @@ -159,6 +160,8 @@ func (l *OpenAI) Generate(ctx context.Context, prompt string, optFns ...func(o *

defer stream.Close()

tokens := []string{}

streamProcessing:
for {
select {
Expand All @@ -180,9 +183,13 @@ func (l *OpenAI) Generate(ctx context.Context, prompt string, optFns ...func(o *
return nil, err
}

choices = append(choices, res.Choices...)
tokens = append(tokens, res.Choices[0].Text)
}
}

choices = append(choices, openai.CompletionChoice{
Text: strings.Join(tokens, ""),
})
} else {
res, err := l.createCompletionWithRetry(ctx, completionRequest)
if err != nil {
Expand Down
4 changes: 1 addition & 3 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ func LLMGenerate(ctx context.Context, model schema.LLM, prompt string, optFns ..
}

cm := callback.NewManager(opts.Callbacks, model.Callbacks(), model.Verbose(), func(mo *callback.ManagerOptions) {
if opts.ParentRunID != "" {
mo.ParentRunID = opts.ParentRunID
}
mo.ParentRunID = opts.ParentRunID
})

rm, err := cm.OnLLMStart(ctx, &schema.LLMStartManagerInput{
Expand Down

0 comments on commit c4c7e94

Please sign in to comment.