Skip to content

Commit

Permalink
Refactor: Move SetViperStructDefaults to utils
Browse files Browse the repository at this point in the history
Signed-off-by: Vyom-Yadav <[email protected]>
  • Loading branch information
Vyom-Yadav committed Apr 3, 2024
1 parent 6f69727 commit 2fa1d82
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 87 deletions.
88 changes: 1 addition & 87 deletions internal/config/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@
package server

import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"unicode"

"github.com/spf13/viper"

Expand Down Expand Up @@ -66,87 +62,5 @@ func DefaultConfigForTest() *Config {
func SetViperDefaults(v *viper.Viper) {
v.SetEnvPrefix("minder")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
setViperStructDefaults(v, "", Config{})
}

// setViperStructDefaults recursively sets the viper default values for the given struct.
//
// Per https://github.com/spf13/viper/issues/188#issuecomment-255519149, and
// https://github.com/spf13/viper/issues/761, we need to call viper.SetDefault() for each
// field in the struct to be able to use env var overrides. This also lets us use the
// struct as the source of default values, so yay?
func setViperStructDefaults(v *viper.Viper, prefix string, s any) {
structType := reflect.TypeOf(s)

for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if unicode.IsLower([]rune(field.Name)[0]) {
// Skip private fields
continue
}
if field.Tag.Get("mapstructure") == "" {
// Error, need a tag
panic(fmt.Sprintf("Untagged config struct field %q", field.Name))
}
valueName := strings.ToLower(prefix + field.Tag.Get("mapstructure"))
fieldType := field.Type

// Extract a default value the `default` struct tag
// we don't support all value types yet, but we can add them as needed
value := field.Tag.Get("default")

// Dereference one level of pointers, if present
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}

if fieldType.Kind() == reflect.Struct {
setViperStructDefaults(v, valueName+".", reflect.Zero(fieldType).Interface())
if _, ok := field.Tag.Lookup("default"); ok {
overrideViperStructDefaults(v, valueName, value)
}
continue
}

defaultValue := reflect.Zero(field.Type).Interface()
var err error // We handle errors at the end of the switch
//nolint:golint,exhaustive
switch fieldType.Kind() {
case reflect.String:
defaultValue = value
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
defaultValue, err = strconv.Atoi(value)
case reflect.Float64:
defaultValue, err = strconv.ParseFloat(value, 64)
case reflect.Bool:
defaultValue, err = strconv.ParseBool(value)
case reflect.Slice:
defaultValue = nil
default:
err = fmt.Errorf("unhandled type %s", fieldType)
}
if err != nil {
// This is effectively a compile-time error, so exit early
panic(fmt.Sprintf("Bad value for field %q (%s): %q", valueName, fieldType, err))
}

if err := v.BindEnv(strings.ToUpper(valueName)); err != nil {
panic(fmt.Sprintf("Failed to bind %q to env var: %v", valueName, err))
}
v.SetDefault(valueName, defaultValue)
}
}

func overrideViperStructDefaults(v *viper.Viper, prefix string, newDefaults string) {
overrides := map[string]any{}
if err := json.Unmarshal([]byte(newDefaults), &overrides); err != nil {
panic(fmt.Sprintf("Failed to parse overrides in %q: %v", prefix, err))
}

for key, value := range overrides {
// TODO: we don't do any fancy type checking here, so this could blow up later.
// I expect it will blow up at config-parse time, which should be earlier enough.
v.SetDefault(prefix+"."+key, value)
}
config.SetViperStructDefaults(v, "", Config{})
}
87 changes: 87 additions & 0 deletions internal/config/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@
package config

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"unicode"

"github.com/spf13/pflag"
"github.com/spf13/viper"
Expand Down Expand Up @@ -187,3 +192,85 @@ func ReadConfigFromViper[CFG any](v *viper.Viper) (*CFG, error) {
}
return &cfg, nil
}

// SetViperStructDefaults recursively sets the viper default values for the given struct.
//
// Per https://github.com/spf13/viper/issues/188#issuecomment-255519149, and
// https://github.com/spf13/viper/issues/761, we need to call viper.SetDefault() for each
// field in the struct to be able to use env var overrides. This also lets us use the
// struct as the source of default values, so yay?
func SetViperStructDefaults(v *viper.Viper, prefix string, s any) {
structType := reflect.TypeOf(s)

for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if unicode.IsLower([]rune(field.Name)[0]) {
// Skip private fields
continue
}
if field.Tag.Get("mapstructure") == "" {
// Error, need a tag
panic(fmt.Sprintf("Untagged config struct field %q", field.Name))
}
valueName := strings.ToLower(prefix + field.Tag.Get("mapstructure"))
fieldType := field.Type

// Extract a default value the `default` struct tag
// we don't support all value types yet, but we can add them as needed
value := field.Tag.Get("default")

// Dereference one level of pointers, if present
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}

if fieldType.Kind() == reflect.Struct {
SetViperStructDefaults(v, valueName+".", reflect.Zero(fieldType).Interface())
if _, ok := field.Tag.Lookup("default"); ok {
overrideViperStructDefaults(v, valueName, value)
}
continue
}

defaultValue := reflect.Zero(field.Type).Interface()
var err error // We handle errors at the end of the switch
//nolint:golint,exhaustive
switch fieldType.Kind() {
case reflect.String:
defaultValue = value
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
defaultValue, err = strconv.Atoi(value)
case reflect.Float64:
defaultValue, err = strconv.ParseFloat(value, 64)
case reflect.Bool:
defaultValue, err = strconv.ParseBool(value)
case reflect.Slice:
defaultValue = nil
default:
err = fmt.Errorf("unhandled type %s", fieldType)
}
if err != nil {
// This is effectively a compile-time error, so exit early
panic(fmt.Sprintf("Bad value for field %q (%s): %q", valueName, fieldType, err))
}

if err := v.BindEnv(strings.ToUpper(valueName)); err != nil {
panic(fmt.Sprintf("Failed to bind %q to env var: %v", valueName, err))
}
v.SetDefault(valueName, defaultValue)
}
}

func overrideViperStructDefaults(v *viper.Viper, prefix string, newDefaults string) {
overrides := map[string]any{}
if err := json.Unmarshal([]byte(newDefaults), &overrides); err != nil {
panic(fmt.Sprintf("Failed to parse overrides in %q: %v", prefix, err))
}

for key, value := range overrides {
// TODO: we don't do any fancy type checking here, so this could blow up later.
// I expect it will blow up at config-parse time, which should be earlier enough.
v.SetDefault(prefix+"."+key, value)
}
}

0 comments on commit 2fa1d82

Please sign in to comment.