Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 22, 2023
1 parent ae727dd commit 33f405e
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 5 deletions.
9 changes: 9 additions & 0 deletions chain/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Current conversation:
Human: {{.input}}
AI:`

// Compile time check to ensure Conversation satisfies the Chain interface.
var _ schema.Chain = (*Conversation)(nil)

type ConversationOptions struct {
*schema.CallbackOptions
Prompt *prompt.Template
Expand Down Expand Up @@ -59,6 +62,8 @@ func NewConversation(llm schema.LLM, optFns ...func(o *ConversationOptions)) (*C
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *Conversation) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
promptValue, err := c.opts.Prompt.FormatPrompt(inputs)
if err != nil {
Expand All @@ -81,18 +86,22 @@ func (c *Conversation) Prompt() *prompt.Template {
return c.opts.Prompt
}

// Memory returns the memory associated with the chain.
func (c *Conversation) Memory() schema.Memory {
return c.opts.Memory
}

// Type returns the type of the chain.
func (c *Conversation) Type() string {
return "Conversation"
}

// Verbose returns the verbosity setting of the chain.
func (c *Conversation) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *Conversation) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
14 changes: 14 additions & 0 deletions chain/conversational_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,29 @@ Chat History:
Follow Up Input: {{.query}}
Standalone question:`

// Compile time check to ensure ConversationalRetrieval satisfies the Chain interface.
var _ schema.Chain = (*ConversationalRetrieval)(nil)

// ConversationalRetrievalOptions represents the options for the ConversationalRetrieval chain.
type ConversationalRetrievalOptions struct {
*schema.CallbackOptions
ReturnSourceDocuments bool
ReturnGeneratedQuestion bool
CondenseQuestionPrompt *prompt.Template
StuffQAPrompt *prompt.Template
Memory schema.Memory
InputKey string
OutputKey string
}

// ConversationalRetrieval is a chain implementation for conversational retrieval.
type ConversationalRetrieval struct {
condenseQuestionChain *LLM
retrievalQAChain *RetrievalQA
opts ConversationalRetrievalOptions
}

// NewConversationalRetrieval creates a new instance of the ConversationalRetrieval chain.
func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *ConversationalRetrievalOptions)) (*ConversationalRetrieval, error) {
opts := ConversationalRetrievalOptions{
CallbackOptions: &schema.CallbackOptions{
Expand Down Expand Up @@ -69,6 +76,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
}

retrievalQAChain, err := NewRetrievalQA(llm, retriever, func(o *RetrievalQAOptions) {
o.StuffQAPrompt = opts.StuffQAPrompt
o.ReturnSourceDocuments = opts.ReturnSourceDocuments
o.InputKey = opts.InputKey
})
Expand All @@ -83,6 +91,8 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
output, err := golc.Call(ctx, c.condenseQuestionChain, inputs)
if err != nil {
Expand Down Expand Up @@ -121,18 +131,22 @@ func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainVa
return returns, nil
}

// Memory returns the memory associated with the chain.
func (c ConversationalRetrieval) Memory() schema.Memory {
return c.opts.Memory
}

// Type returns the type of the chain.
func (c ConversationalRetrieval) Type() string {
return "ConversationalRetrieval"
}

// Verbose returns the verbosity setting of the chain.
func (c ConversationalRetrieval) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c ConversationalRetrieval) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
9 changes: 9 additions & 0 deletions chain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure LLM satisfies the Chain interface.
var _ schema.Chain = (*LLM)(nil)

type LLMOptions struct {
*schema.CallbackOptions
Memory schema.Memory
Expand Down Expand Up @@ -42,6 +45,8 @@ func NewLLM(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMOption
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *LLM) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
promptValue, err := c.prompt.FormatPrompt(inputs)
if err != nil {
Expand All @@ -64,18 +69,22 @@ func (c *LLM) Prompt() *prompt.Template {
return c.prompt
}

// Memory returns the memory associated with the chain.
func (c *LLM) Memory() schema.Memory {
return c.opts.Memory
}

// Type returns the type of the chain.
func (c *LLM) Type() string {
return "LLM"
}

// Verbose returns the verbosity setting of the chain.
func (c *LLM) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *LLM) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
9 changes: 9 additions & 0 deletions chain/llm_bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ That is the format. Begin!
Question: {{.question}}`

// Compile time check to ensure LLMBash satisfies the Chain interface.
var _ schema.Chain = (*LLMBash)(nil)

type LLMBashOptions struct {
*schema.CallbackOptions
InputKey string
Expand Down Expand Up @@ -78,6 +81,8 @@ func NewLLMBash(llm schema.LLM, optFns ...func(o *LLMBashOptions)) (*LLMBash, er
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *LLMBash) Call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
Expand Down Expand Up @@ -114,18 +119,22 @@ func (c *LLMBash) Call(ctx context.Context, values schema.ChainValues) (schema.C
}, nil
}

// Memory returns the memory associated with the chain.
func (c *LLMBash) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *LLMBash) Type() string {
return "LLMBash"
}

// Verbose returns the verbosity setting of the chain.
func (c *LLMBash) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated chain.
func (c *LLMBash) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
9 changes: 9 additions & 0 deletions chain/llm_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ Question: 37593^(1/5)
Question: {{.question}}
`

// Compile time check to ensure LLMMath satisfies the Chain interface.
var _ schema.Chain = (*LLMMath)(nil)

type LLMMathOptions struct {
*schema.CallbackOptions
InputKey string
Expand Down Expand Up @@ -78,6 +81,8 @@ func NewLLMMath(llm schema.LLM, optFns ...func(o *LLMMathOptions)) (*LLMMath, er
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *LLMMath) Call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
Expand Down Expand Up @@ -127,18 +132,22 @@ func (c *LLMMath) evaluateExpression(expression string) (string, error) {
return fmt.Sprintf("%f", output), nil
}

// Memory returns the memory associated with the chain.
func (c *LLMMath) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *LLMMath) Type() string {
return "LLMMath"
}

// Verbose returns the verbosity setting of the chain.
func (c *LLMMath) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *LLMMath) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
6 changes: 6 additions & 0 deletions chain/refine_documents.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ func NewRefineDocuments(llmChain *LLM, refineLLMChain *LLM, optFns ...func(o *Re
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *RefineDocuments) Call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
Expand Down Expand Up @@ -101,18 +103,22 @@ func (c *RefineDocuments) Call(ctx context.Context, values schema.ChainValues) (
}, nil
}

// Memory returns the memory associated with the chain.
func (c *RefineDocuments) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *RefineDocuments) Type() string {
return "RefineDocuments"
}

// Verbose returns the verbosity setting of the chain.
func (c *RefineDocuments) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *RefineDocuments) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
19 changes: 15 additions & 4 deletions chain/retrieval_qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

type RetrievalQAOptions struct {
*schema.CallbackOptions
StuffQAPrompt *prompt.Template
InputKey string
ReturnSourceDocuments bool
}
Expand All @@ -34,12 +35,16 @@ func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o
fn(&opts)
}

stuffPrompt, err := prompt.NewTemplate(defaultStuffQAPromptTemplate)
if err != nil {
return nil, err
if opts.StuffQAPrompt == nil {
p, err := prompt.NewTemplate(defaultStuffQAPromptTemplate)
if err != nil {
return nil, err
}

opts.StuffQAPrompt = p
}

llmChain, err := NewLLM(llm, stuffPrompt)
llmChain, err := NewLLM(llm, opts.StuffQAPrompt)
if err != nil {
return nil, err
}
Expand All @@ -56,6 +61,8 @@ func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *RetrievalQA) Call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
Expand Down Expand Up @@ -87,18 +94,22 @@ func (c *RetrievalQA) Call(ctx context.Context, values schema.ChainValues) (sche
return result, nil
}

// Memory returns the memory associated with the chain.
func (c *RetrievalQA) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *RetrievalQA) Type() string {
return "RetrievalQA"
}

// Verbose returns the verbosity setting of the chain.
func (c *RetrievalQA) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *RetrievalQA) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
6 changes: 6 additions & 0 deletions chain/stuff_documents.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ func NewStuffDocuments(llmChain *LLM, optFns ...func(o *StuffDocumentsOptions))
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *StuffDocuments) Call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
Expand All @@ -64,18 +66,22 @@ func (c *StuffDocuments) Call(ctx context.Context, values schema.ChainValues) (s
return golc.Call(ctx, c.llmChain, inputValues)
}

// Memory returns the memory associated with the chain.
func (c *StuffDocuments) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *StuffDocuments) Type() string {
return "StuffDocuments"
}

// Verbose returns the verbosity setting of the chain.
func (c *StuffDocuments) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *StuffDocuments) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
6 changes: 6 additions & 0 deletions chain/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,28 @@ func NewTransform(inputKeys, outputKeys []string, transform TransformFunc, optFn
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *Transform) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
return c.transform(inputs)
}

// Memory returns the memory associated with the chain.
func (c *Transform) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *Transform) Type() string {
return "Transform"
}

// Verbose returns the verbosity setting of the chain.
func (c *Transform) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *Transform) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}
Expand Down
Loading

0 comments on commit 33f405e

Please sign in to comment.