From 89153dd235687c87571a8f55fa93659fded19c5f Mon Sep 17 00:00:00 2001 From: Matt Joyce Date: Sat, 14 Dec 2024 14:37:12 +1100 Subject: [PATCH] feat: Add YAML configuration support Add support for persistent configuration via YAML files. Users can now specify common options in a config file while maintaining the ability to override with CLI flags. Currently supports core options like model, temperature, and pattern settings. - Add --config flag for specifying YAML config path - Support standard option precedence (CLI > YAML > defaults) - Add type-safe YAML parsing with reflection - Add tests for YAML config functionality --- cli/flags.go | 184 ++++++++++++++++++++++++++++++++++++++++------ cli/flags_test.go | 77 +++++++++++++++++++ go.mod | 6 +- go.sum | 3 +- 4 files changed, 243 insertions(+), 27 deletions(-) diff --git a/cli/flags.go b/cli/flags.go index abb9a111b..38eb7b83f 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -6,29 +6,31 @@ import ( "fmt" "io" "os" + "reflect" "strings" "github.com/jessevdk/go-flags" goopenai "github.com/sashabaranov/go-openai" "golang.org/x/text/language" + "gopkg.in/yaml.v2" "github.com/danielmiessler/fabric/common" ) // Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli type Flags struct { - Pattern string `short:"p" long:"pattern" description:"Choose a pattern from the available patterns" default:""` + Pattern string `short:"p" long:"pattern" yaml:"pattern" description:"Choose a pattern from the available patterns" default:""` PatternVariables map[string]string `short:"v" long:"variable" description:"Values for pattern variables, e.g. -v=#role:expert -v=#points:30"` Context string `short:"C" long:"context" description:"Choose a context from the available contexts" default:""` Session string `long:"session" description:"Choose a session from the available sessions"` Attachments []string `short:"a" long:"attachment" description:"Attachment path or URL (e.g. for OpenAI image recognition messages)"` Setup bool `short:"S" long:"setup" description:"Run setup for all reconfigurable parts of fabric"` - Temperature float64 `short:"t" long:"temperature" description:"Set temperature" default:"0.7"` - TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"` - Stream bool `short:"s" long:"stream" description:"Stream"` - PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"` - Raw bool `short:"r" long:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."` - FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"` + Temperature float64 `short:"t" long:"temperature" yaml:"temperature" description:"Set temperature" default:"0.7"` + TopP float64 `short:"T" long:"topp" yaml:"topp" description:"Set top P" default:"0.9"` + Stream bool `short:"s" long:"stream" yaml:"stream" description:"Stream"` + PresencePenalty float64 `short:"P" long:"presencepenalty" yaml:"presencepenalty" description:"Set presence penalty" default:"0.0"` + Raw bool `short:"r" long:"raw" yaml:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."` + FrequencyPenalty float64 `short:"F" long:"frequencypenalty" yaml:"frequencypenalty" description:"Set frequency penalty" default:"0.0"` ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"` ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"` ListAllContexts bool `short:"x" long:"listcontexts" description:"List all contexts"` @@ -36,8 +38,8 @@ type Flags struct { UpdatePatterns bool `short:"U" long:"updatepatterns" description:"Update patterns"` Message string `hidden:"true" description:"Messages to send to chat"` Copy bool `short:"c" long:"copy" description:"Copy to clipboard"` - Model string `short:"m" long:"model" description:"Choose model"` - ModelContextLength int `long:"modelContextLength" description:"Model context length (only affects ollama)"` + Model string `short:"m" long:"model" yaml:"model" description:"Choose model"` + ModelContextLength int `long:"modelContextLength" yaml:"modelContextLength" description:"Model context length (only affects ollama)"` Output string `short:"o" long:"output" description:"Output to file" default:""` OutputSession bool `long:"output-session" description:"Output the entire session (also a temporary one) to the output file"` LatestPatterns string `short:"n" long:"latest" description:"Number of latest patterns to list" default:"0"` @@ -49,7 +51,7 @@ type Flags struct { Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""` ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"` ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"` - Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"` + Seed int `short:"e" long:"seed" yaml:"seed" description:"Seed to be used for LMM generation"` WipeContext string `short:"w" long:"wipecontext" description:"Wipe context"` WipeSession string `short:"W" long:"wipesession" description:"Wipe session"` PrintContext string `long:"printcontext" description:"Print context"` @@ -59,37 +61,175 @@ type Flags struct { DryRun bool `long:"dry-run" description:"Show what would be sent to the model without actually sending it"` Serve bool `long:"serve" description:"Serve the Fabric Rest API"` ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"` + Config string `long:"config" description:"Path to YAML config file"` Version bool `long:"version" description:"Print current version"` } +var debug = false + +func Debugf(format string, a ...interface{}) { + if debug { + fmt.Printf("DEBUG: "+format, a...) + } +} + + // Init Initialize flags. returns a Flags struct and an error func Init() (ret *Flags, err error) { + // Track which yaml-configured flags were set on CLI + usedFlags := make(map[string]bool) + args := os.Args[1:] + + // Get list of fields that have yaml tags, could be in yaml config + yamlFields := make(map[string]bool) + t := reflect.TypeOf(Flags{}) + for i := 0; i < t.NumField(); i++ { + if yamlTag := t.Field(i).Tag.Get("yaml"); yamlTag != "" { + yamlFields[yamlTag] = true + //Debugf("Found yaml-configured field: %s\n", yamlTag) + } + } + + // Scan args for that are provided by cli and might be in yaml + for _, arg := range args { + if strings.HasPrefix(arg, "--") { + flag := strings.TrimPrefix(arg, "--") + if i := strings.Index(flag, "="); i > 0 { + flag = flag[:i] + } + if yamlFields[flag] { + usedFlags[flag] = true + Debugf("CLI flag used: %s\n", flag) + } + } + } + + // Parse CLI flags first ret = &Flags{} parser := flags.NewParser(ret, flags.Default) - var args []string - if args, err = parser.Parse(); err != nil { - return + if _, err = parser.Parse(); err != nil { + return nil, err } + // If config specified, load and apply YAML for unused flags + if ret.Config != "" { + yamlFlags, err := loadYAMLConfig(ret.Config) + if err != nil { + return nil, err + } + + // Apply YAML values where CLI flags weren't used + flagsVal := reflect.ValueOf(ret).Elem() + yamlVal := reflect.ValueOf(yamlFlags).Elem() + flagsType := flagsVal.Type() + + for i := 0; i < flagsType.NumField(); i++ { + field := flagsType.Field(i) + if yamlTag := field.Tag.Get("yaml"); yamlTag != "" { + if !usedFlags[yamlTag] { + flagField := flagsVal.Field(i) + yamlField := yamlVal.Field(i) + if flagField.CanSet() { + if yamlField.Type() != flagField.Type() { + if err := assignWithConversion(flagField, yamlField); err != nil { + Debugf("Type conversion failed for %s: %v\n", yamlTag, err) + continue + } + } else { + flagField.Set(yamlField) + } + Debugf("Applied YAML value for %s: %v\n", yamlTag, yamlField.Interface()) + } + } + } + } + } + + // Handle stdin and messages info, _ := os.Stdin.Stat() pipedToStdin := (info.Mode() & os.ModeCharDevice) == 0 - //custom message if len(args) > 0 { - ret.Message = AppendMessage(ret.Message, args[len(args)-1]) + ret.Message = AppendMessage(ret.Message, args[len(args)-1]) } - // takes input from stdin if it exists, otherwise takes input from args (the last argument) if pipedToStdin { - var pipedMessage string - if pipedMessage, err = readStdin(); err != nil { - return - } - ret.Message = AppendMessage(ret.Message, pipedMessage) + var pipedMessage string + if pipedMessage, err = readStdin(); err != nil { + return + } + ret.Message = AppendMessage(ret.Message, pipedMessage) } - return + + return ret, nil } + + +func assignWithConversion(targetField, sourceField reflect.Value) error { + switch targetField.Kind() { + case reflect.Float64: + if sourceField.Kind() == reflect.Int || sourceField.Kind() == reflect.Float32 { + targetField.SetFloat(float64(sourceField.Convert(reflect.TypeOf(float64(0))).Float())) + Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface()) + return nil + } + case reflect.Int: + if sourceField.Kind() == reflect.Float64 || sourceField.Kind() == reflect.Float32 { + targetField.SetInt(int64(sourceField.Convert(reflect.TypeOf(int64(0))).Int())) + Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface()) + return nil + } + case reflect.String: + if sourceField.Kind() == reflect.Interface { + if str, ok := sourceField.Interface().(string); ok { + targetField.SetString(str) + Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface()) + return nil + } + } + case reflect.Bool: + if sourceField.Kind() == reflect.Interface { + if b, ok := sourceField.Interface().(bool); ok { + targetField.SetBool(b) + Debugf("Converted field %s : %v\n", targetField.Type(), targetField.Interface()) + return nil + } + } + } + + return fmt.Errorf("unsupported conversion: %s to %s", sourceField.Type(), targetField.Type()) +} + + + + +func loadYAMLConfig(configPath string) (*Flags, error) { + absPath, err := common.GetAbsolutePath(configPath) + if err != nil { + return nil, fmt.Errorf("invalid config path: %w", err) + } + + data, err := os.ReadFile(absPath) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("config file not found: %s", absPath) + } + return nil, fmt.Errorf("error reading config file: %w", err) + } + + // Use the existing Flags struct for YAML unmarshal + config := &Flags{} + if err := yaml.Unmarshal(data, config); err != nil { + return nil, fmt.Errorf("error parsing config file: %w", err) + } + + Debugf("Config: %v\n", config) + + return config, nil +} + + // readStdin reads from stdin and returns the input as a string or an error func readStdin() (ret string, err error) { reader := bufio.NewReader(os.Stdin) diff --git a/cli/flags_test.go b/cli/flags_test.go index 4167bb448..197d2f9ed 100644 --- a/cli/flags_test.go +++ b/cli/flags_test.go @@ -87,3 +87,80 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) { options := flags.BuildChatOptions() assert.Equal(t, expectedOptions, options) } + +func TestInitWithYAMLConfig(t *testing.T) { + // Create a temporary YAML config file + configContent := ` +temperature: 0.9 +model: gpt-4 +pattern: analyze +stream: true +` + tmpfile, err := os.CreateTemp("", "config.*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write([]byte(configContent)); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + // Test 1: Basic YAML loading + t.Run("Load YAML config", func(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = []string{"cmd", "--config", tmpfile.Name()} + + flags, err := Init() + assert.NoError(t, err) + assert.Equal(t, 0.9, flags.Temperature) + assert.Equal(t, "gpt-4", flags.Model) + assert.Equal(t, "analyze", flags.Pattern) + assert.True(t, flags.Stream) + }) + + // Test 2: CLI overrides YAML + t.Run("CLI overrides YAML", func(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = []string{"cmd", "--config", tmpfile.Name(), "--temperature", "0.7", "--model", "gpt-3.5-turbo"} + + flags, err := Init() + assert.NoError(t, err) + assert.Equal(t, 0.7, flags.Temperature) + assert.Equal(t, "gpt-3.5-turbo", flags.Model) + assert.Equal(t, "analyze", flags.Pattern) // unchanged from YAML + assert.True(t, flags.Stream) // unchanged from YAML + }) + + // Test 3: Invalid YAML config + t.Run("Invalid YAML config", func(t *testing.T) { + badConfig := ` +temperature: "not a float" +model: 123 # should be string +` + badfile, err := os.CreateTemp("", "bad-config.*.yaml") + if err != nil { + t.Fatal(err) + } + defer os.Remove(badfile.Name()) + + if _, err := badfile.Write([]byte(badConfig)); err != nil { + t.Fatal(err) + } + if err := badfile.Close(); err != nil { + t.Fatal(err) + } + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = []string{"cmd", "--config", badfile.Name()} + + _, err = Init() + assert.Error(t, err) + }) +} \ No newline at end of file diff --git a/go.mod b/go.mod index 854630194..8d5363ecb 100644 --- a/go.mod +++ b/go.mod @@ -6,14 +6,15 @@ toolchain go1.23.1 require ( github.com/anaskhan96/soup v1.2.5 + github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 github.com/atotto/clipboard v0.1.4 github.com/gabriel-vasile/mimetype v1.4.6 github.com/gin-gonic/gin v1.10.0 github.com/go-git/go-git/v5 v5.12.0 github.com/go-shiori/go-readability v0.0.0-20241012063810-92284fa8a71f + github.com/google/generative-ai-go v0.18.0 github.com/jessevdk/go-flags v1.6.1 github.com/joho/godotenv v1.5.1 - github.com/liushuangls/go-anthropic/v2 v2.11.0 github.com/ollama/ollama v0.4.1 github.com/otiai10/copy v1.14.0 github.com/pkg/errors v0.9.1 @@ -22,6 +23,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/text v0.20.0 google.golang.org/api v0.205.0 + gopkg.in/yaml.v2 v2.4.0 ) require ( @@ -35,7 +37,6 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.1.2 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect - github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 // indirect github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect github.com/bytedance/sonic v1.12.4 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect @@ -58,7 +59,6 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/google/generative-ai-go v0.18.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect diff --git a/go.sum b/go.sum index 0fc0507a5..963b00b6a 100644 --- a/go.sum +++ b/go.sum @@ -158,8 +158,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/liushuangls/go-anthropic/v2 v2.11.0 h1:YKyxDWQNaKPPgtLCgBH+JqzuznNWw8ZqQVeSdQNDMds= -github.com/liushuangls/go-anthropic/v2 v2.11.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= @@ -360,6 +358,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=