Skip to content

Commit

Permalink
Merge pull request #1201 from mattjoyce/feature/config-yaml
Browse files Browse the repository at this point in the history
feat: Add YAML configuration support
  • Loading branch information
eugeis authored Dec 14, 2024
2 parents aa2881f + 89153dd commit f180e8f
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 27 deletions.
184 changes: 162 additions & 22 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,40 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/jessevdk/go-flags"
goopenai "github.com/sashabaranov/go-openai"
"golang.org/x/text/language"
"gopkg.in/yaml.v2"

"github.com/danielmiessler/fabric/common"
)

// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
type Flags struct {
Pattern string `short:"p" long:"pattern" description:"Choose a pattern from the available patterns" default:""`
Pattern string `short:"p" long:"pattern" yaml:"pattern" description:"Choose a pattern from the available patterns" default:""`
PatternVariables map[string]string `short:"v" long:"variable" description:"Values for pattern variables, e.g. -v=#role:expert -v=#points:30"`
Context string `short:"C" long:"context" description:"Choose a context from the available contexts" default:""`
Session string `long:"session" description:"Choose a session from the available sessions"`
Attachments []string `short:"a" long:"attachment" description:"Attachment path or URL (e.g. for OpenAI image recognition messages)"`
Setup bool `short:"S" long:"setup" description:"Run setup for all reconfigurable parts of fabric"`
Temperature float64 `short:"t" long:"temperature" description:"Set temperature" default:"0.7"`
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
Raw bool `short:"r" long:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
Temperature float64 `short:"t" long:"temperature" yaml:"temperature" description:"Set temperature" default:"0.7"`
TopP float64 `short:"T" long:"topp" yaml:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" yaml:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" yaml:"presencepenalty" description:"Set presence penalty" default:"0.0"`
Raw bool `short:"r" long:"raw" yaml:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" yaml:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
ListAllContexts bool `short:"x" long:"listcontexts" description:"List all contexts"`
ListAllSessions bool `short:"X" long:"listsessions" description:"List all sessions"`
UpdatePatterns bool `short:"U" long:"updatepatterns" description:"Update patterns"`
Message string `hidden:"true" description:"Messages to send to chat"`
Copy bool `short:"c" long:"copy" description:"Copy to clipboard"`
Model string `short:"m" long:"model" description:"Choose model"`
ModelContextLength int `long:"modelContextLength" description:"Model context length (only affects ollama)"`
Model string `short:"m" long:"model" yaml:"model" description:"Choose model"`
ModelContextLength int `long:"modelContextLength" yaml:"modelContextLength" description:"Model context length (only affects ollama)"`
Output string `short:"o" long:"output" description:"Output to file" default:""`
OutputSession bool `long:"output-session" description:"Output the entire session (also a temporary one) to the output file"`
LatestPatterns string `short:"n" long:"latest" description:"Number of latest patterns to list" default:"0"`
Expand All @@ -49,7 +51,7 @@ type Flags struct {
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
Seed int `short:"e" long:"seed" yaml:"seed" description:"Seed to be used for LMM generation"`
WipeContext string `short:"w" long:"wipecontext" description:"Wipe context"`
WipeSession string `short:"W" long:"wipesession" description:"Wipe session"`
PrintContext string `long:"printcontext" description:"Print context"`
Expand All @@ -59,37 +61,175 @@ type Flags struct {
DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"`
Serve bool `long:"serve" description:"Serve the Fabric Rest API"`
ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"`
Config string `long:"config" description:"Path to YAML config file"`
Version bool `long:"version" description:"Print current version"`
}

var debug = false

func Debugf(format string, a ...interface{}) {
if debug {
fmt.Printf("DEBUG: "+format, a...)
}
}


// Init Initialize flags. returns a Flags struct and an error
func Init() (ret *Flags, err error) {
// Track which yaml-configured flags were set on CLI
usedFlags := make(map[string]bool)
args := os.Args[1:]

// Get list of fields that have yaml tags, could be in yaml config
yamlFields := make(map[string]bool)
t := reflect.TypeOf(Flags{})
for i := 0; i < t.NumField(); i++ {
if yamlTag := t.Field(i).Tag.Get("yaml"); yamlTag != "" {
yamlFields[yamlTag] = true
//Debugf("Found yaml-configured field: %s\n", yamlTag)
}
}

// Scan args for that are provided by cli and might be in yaml
for _, arg := range args {
if strings.HasPrefix(arg, "--") {
flag := strings.TrimPrefix(arg, "--")
if i := strings.Index(flag, "="); i > 0 {
flag = flag[:i]
}
if yamlFields[flag] {
usedFlags[flag] = true
Debugf("CLI flag used: %s\n", flag)
}
}
}

// Parse CLI flags first
ret = &Flags{}
parser := flags.NewParser(ret, flags.Default)
var args []string
if args, err = parser.Parse(); err != nil {
return
if _, err = parser.Parse(); err != nil {
return nil, err
}

// If config specified, load and apply YAML for unused flags
if ret.Config != "" {
yamlFlags, err := loadYAMLConfig(ret.Config)
if err != nil {
return nil, err
}

// Apply YAML values where CLI flags weren't used
flagsVal := reflect.ValueOf(ret).Elem()
yamlVal := reflect.ValueOf(yamlFlags).Elem()
flagsType := flagsVal.Type()

for i := 0; i < flagsType.NumField(); i++ {
field := flagsType.Field(i)
if yamlTag := field.Tag.Get("yaml"); yamlTag != "" {
if !usedFlags[yamlTag] {
flagField := flagsVal.Field(i)
yamlField := yamlVal.Field(i)
if flagField.CanSet() {
if yamlField.Type() != flagField.Type() {
if err := assignWithConversion(flagField, yamlField); err != nil {
Debugf("Type conversion failed for %s: %v\n", yamlTag, err)
continue
}
} else {
flagField.Set(yamlField)
}
Debugf("Applied YAML value for %s: %v\n", yamlTag, yamlField.Interface())
}
}
}
}
}

// Handle stdin and messages
info, _ := os.Stdin.Stat()
pipedToStdin := (info.Mode() & os.ModeCharDevice) == 0

//custom message
if len(args) > 0 {
ret.Message = AppendMessage(ret.Message, args[len(args)-1])
ret.Message = AppendMessage(ret.Message, args[len(args)-1])
}

// takes input from stdin if it exists, otherwise takes input from args (the last argument)
if pipedToStdin {
var pipedMessage string
if pipedMessage, err = readStdin(); err != nil {
return
}
ret.Message = AppendMessage(ret.Message, pipedMessage)
var pipedMessage string
if pipedMessage, err = readStdin(); err != nil {
return
}
ret.Message = AppendMessage(ret.Message, pipedMessage)
}
return

return ret, nil
}



func assignWithConversion(targetField, sourceField reflect.Value) error {
switch targetField.Kind() {
case reflect.Float64:
if sourceField.Kind() == reflect.Int || sourceField.Kind() == reflect.Float32 {
targetField.SetFloat(float64(sourceField.Convert(reflect.TypeOf(float64(0))).Float()))
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
case reflect.Int:
if sourceField.Kind() == reflect.Float64 || sourceField.Kind() == reflect.Float32 {
targetField.SetInt(int64(sourceField.Convert(reflect.TypeOf(int64(0))).Int()))
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
case reflect.String:
if sourceField.Kind() == reflect.Interface {
if str, ok := sourceField.Interface().(string); ok {
targetField.SetString(str)
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
}
case reflect.Bool:
if sourceField.Kind() == reflect.Interface {
if b, ok := sourceField.Interface().(bool); ok {
targetField.SetBool(b)
Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface())
return nil
}
}
}

return fmt.Errorf("unsupported conversion: %s to %s", sourceField.Type(), targetField.Type())
}




func loadYAMLConfig(configPath string) (*Flags, error) {
absPath, err := common.GetAbsolutePath(configPath)
if err != nil {
return nil, fmt.Errorf("invalid config path: %w", err)
}

data, err := os.ReadFile(absPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("config file not found: %s", absPath)
}
return nil, fmt.Errorf("error reading config file: %w", err)
}

// Use the existing Flags struct for YAML unmarshal
config := &Flags{}
if err := yaml.Unmarshal(data, config); err != nil {
return nil, fmt.Errorf("error parsing config file: %w", err)
}

Debugf("Config: %v\n", config)

return config, nil
}


// readStdin reads from stdin and returns the input as a string or an error
func readStdin() (ret string, err error) {
reader := bufio.NewReader(os.Stdin)
Expand Down
77 changes: 77 additions & 0 deletions cli/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,80 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}

func TestInitWithYAMLConfig(t *testing.T) {
// Create a temporary YAML config file
configContent := `
temperature: 0.9
model: gpt-4
pattern: analyze
stream: true
`
tmpfile, err := os.CreateTemp("", "config.*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())

if _, err := tmpfile.Write([]byte(configContent)); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}

// Test 1: Basic YAML loading
t.Run("Load YAML config", func(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", tmpfile.Name()}

flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, 0.9, flags.Temperature)
assert.Equal(t, "gpt-4", flags.Model)
assert.Equal(t, "analyze", flags.Pattern)
assert.True(t, flags.Stream)
})

// Test 2: CLI overrides YAML
t.Run("CLI overrides YAML", func(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", tmpfile.Name(), "--temperature", "0.7", "--model", "gpt-3.5-turbo"}

flags, err := Init()
assert.NoError(t, err)
assert.Equal(t, 0.7, flags.Temperature)
assert.Equal(t, "gpt-3.5-turbo", flags.Model)
assert.Equal(t, "analyze", flags.Pattern) // unchanged from YAML
assert.True(t, flags.Stream) // unchanged from YAML
})

// Test 3: Invalid YAML config
t.Run("Invalid YAML config", func(t *testing.T) {
badConfig := `
temperature: "not a float"
model: 123 # should be string
`
badfile, err := os.CreateTemp("", "bad-config.*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(badfile.Name())

if _, err := badfile.Write([]byte(badConfig)); err != nil {
t.Fatal(err)
}
if err := badfile.Close(); err != nil {
t.Fatal(err)
}

oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--config", badfile.Name()}

_, err = Init()
assert.Error(t, err)
})
}
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ toolchain go1.23.1

require (
github.com/anaskhan96/soup v1.2.5
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4
github.com/atotto/clipboard v0.1.4
github.com/gabriel-vasile/mimetype v1.4.6
github.com/gin-gonic/gin v1.10.0
github.com/go-git/go-git/v5 v5.12.0
github.com/go-shiori/go-readability v0.0.0-20241012063810-92284fa8a71f
github.com/google/generative-ai-go v0.18.0
github.com/jessevdk/go-flags v1.6.1
github.com/joho/godotenv v1.5.1
github.com/liushuangls/go-anthropic/v2 v2.11.0
github.com/ollama/ollama v0.4.1
github.com/otiai10/copy v1.14.0
github.com/pkg/errors v0.9.1
Expand All @@ -22,6 +23,7 @@ require (
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.20.0
google.golang.org/api v0.205.0
gopkg.in/yaml.v2 v2.4.0
)

require (
Expand All @@ -35,7 +37,6 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.1.2 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 // indirect
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
github.com/bytedance/sonic v1.12.4 // indirect
github.com/bytedance/sonic/loader v0.2.1 // indirect
Expand All @@ -58,7 +59,6 @@ require (
github.com/goccy/go-json v0.10.3 // indirect
github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/generative-ai-go v0.18.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
Expand Down
3 changes: 1 addition & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/liushuangls/go-anthropic/v2 v2.11.0 h1:YKyxDWQNaKPPgtLCgBH+JqzuznNWw8ZqQVeSdQNDMds=
github.com/liushuangls/go-anthropic/v2 v2.11.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
Expand Down Expand Up @@ -360,6 +358,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down

0 comments on commit f180e8f

Please sign in to comment.