Skip to content

Commit

Permalink
fix: make WriteConfig and SafeWriteConfig work as expected
Browse files Browse the repository at this point in the history
When a config file is set using SetConfigFile:
– WriteConfig failed if the file did not already exist,
– SafeWriteConfig did not use it.

Fixes #430
Fixes #433

Signed-off-by: Yann Soubeyrand <[email protected]>
  • Loading branch information
yann-soubeyrand committed Feb 9, 2025
1 parent 8b223a4 commit a97213d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 14 deletions.
36 changes: 28 additions & 8 deletions viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ func (str UnsupportedConfigError) Error() string {
return fmt.Sprintf("Unsupported Config Type %q", string(str))
}

// ConfigFileError denotes failing to get configuration file.
type ConfigFileError struct {
error
}

// ConfigFileNotFoundError denotes failing to find configuration file.
type ConfigFileNotFoundError struct {
name, locations string
Expand Down Expand Up @@ -1575,21 +1580,24 @@ func (v *Viper) MergeConfigMap(cfg map[string]any) error {
func WriteConfig() error { return v.WriteConfig() }

func (v *Viper) WriteConfig() error {
filename, err := v.getConfigFile()
_, err := v.getConfigFile()
if err != nil {
return err
}
return v.writeConfig(filename, true)

return v.WriteConfigAs(v.configFile)
}

// SafeWriteConfig writes current configuration to file only if the file does not exist.
func SafeWriteConfig() error { return v.SafeWriteConfig() }

func (v *Viper) SafeWriteConfig() error {
if len(v.configPaths) < 1 {
return errors.New("missing configuration for 'configPath'")
_, err := v.getConfigFile()
if err != nil {
return err
}
return v.SafeWriteConfigAs(filepath.Join(v.configPaths[0], v.configName+"."+v.configType))

return v.SafeWriteConfigAs(v.configFile)
}

// WriteConfigAs writes current configuration to a given filename.
Expand Down Expand Up @@ -2004,11 +2012,23 @@ func (v *Viper) getConfigType() string {

func (v *Viper) getConfigFile() (string, error) {
if v.configFile == "" {
cf, err := v.findConfigFile()
var err error
v.configFile, err = v.findConfigFile()
if err != nil {
return "", err
if _, ok := err.(ConfigFileNotFoundError); !ok {
return "", err
}
if len(v.configPaths) < 1 {
return "", ConfigFileError{errors.New("missing configuration for 'configPath'")}
}
if v.configName == "" {
return "", ConfigFileError{errors.New("missing configuration for 'configName'")}
}
if v.configType == "" {
return "", ConfigFileError{errors.New("missing configuration for 'configType'")}
}
v.configFile = filepath.Join(v.configPaths[0], v.configName+"."+v.configType)
}
v.configFile = cf
}
return v.configFile, nil
}
Expand Down
77 changes: 71 additions & 6 deletions viper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1621,7 +1621,7 @@ func TestWrongDirsSearchNotFound(t *testing.T) {
v.AddConfigPath(`thispathaintthere`)

err := v.ReadInConfig()
assert.IsType(t, ConfigFileNotFoundError{"", ""}, err)
assert.ErrorAs(t, err, &ConfigFileError{})

// Even though config did not load and the error might have
// been ignored by the client, the default still loads
Expand All @@ -1639,7 +1639,7 @@ func TestWrongDirsSearchNotFoundForMerge(t *testing.T) {
v.AddConfigPath(`thispathaintthere`)

err := v.MergeInConfig()
assert.Equal(t, reflect.TypeOf(ConfigFileNotFoundError{"", ""}), reflect.TypeOf(err))
assert.ErrorAs(t, err, &ConfigFileError{})

// Even though config did not load and the error might have
// been ignored by the client, the default still loads
Expand Down Expand Up @@ -1731,7 +1731,7 @@ var jsonWriteExpected = []byte(`{
// name: steve
// `)

func TestWriteConfig(t *testing.T) {
func TestWriteConfigAs(t *testing.T) {
fs := afero.NewMemMapFs()
testCases := map[string]struct {
configName string
Expand Down Expand Up @@ -1809,7 +1809,7 @@ func TestWriteConfig(t *testing.T) {
}
}

func TestWriteConfigTOML(t *testing.T) {
func TestWriteConfigAsTOML(t *testing.T) {
fs := afero.NewMemMapFs()

testCases := map[string]struct {
Expand Down Expand Up @@ -1860,7 +1860,7 @@ func TestWriteConfigTOML(t *testing.T) {
}
}

func TestWriteConfigDotEnv(t *testing.T) {
func TestWriteConfigAsDotEnv(t *testing.T) {
fs := afero.NewMemMapFs()
testCases := map[string]struct {
configName string
Expand Down Expand Up @@ -1909,6 +1909,56 @@ func TestWriteConfigDotEnv(t *testing.T) {
}
}

func TestWriteConfig(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.AddConfigPath("/test")
v.SetConfigName("c")
v.SetConfigType("yaml")
require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample)))
require.NoError(t, v.WriteConfig())
read, err := afero.ReadFile(fs, "/test/c.yaml")
require.NoError(t, err)
assert.Equal(t, yamlWriteExpected, read)
}

func TestWriteConfigWithExplicitlySetFile(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.AddConfigPath("/test1")
v.SetConfigName("c1")
v.SetConfigType("yaml")
v.SetConfigFile("/test2/c2.yaml")
require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample)))
require.NoError(t, v.WriteConfig())
read, err := afero.ReadFile(fs, "/test2/c2.yaml")
require.NoError(t, err)
assert.Equal(t, yamlWriteExpected, read)
}

func TestWriteConfigWithMissingConfigPath(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("yaml")
require.EqualError(t, v.WriteConfig(), "missing configuration for 'configPath'")
}

func TestWriteConfigWithExistingFile(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
fs.Create("/test/c.yaml")
v.SetFs(fs)
v.AddConfigPath("/test")
v.SetConfigName("c")
v.SetConfigType("yaml")
err := v.WriteConfig()
require.NoError(t, err)
}

func TestSafeWriteConfig(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
Expand All @@ -1923,6 +1973,21 @@ func TestSafeWriteConfig(t *testing.T) {
assert.YAMLEq(t, string(yamlWriteExpected), string(read))
}

func TestSafeWriteConfigWithExplicitlySetFile(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.AddConfigPath("/test1")
v.SetConfigName("c1")
v.SetConfigType("yaml")
v.SetConfigFile("/test2/c2.yaml")
require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample)))
require.NoError(t, v.SafeWriteConfig())
read, err := afero.ReadFile(fs, "/test2/c2.yaml")
require.NoError(t, err)
assert.Equal(t, yamlWriteExpected, read)
}

func TestSafeWriteConfigWithMissingConfigPath(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
Expand All @@ -1946,7 +2011,7 @@ func TestSafeWriteConfigWithExistingFile(t *testing.T) {
assert.True(t, ok, "Expected ConfigFileAlreadyExistsError")
}

func TestSafeWriteAsConfig(t *testing.T) {
func TestSafeWriteConfigAs(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
Expand Down

0 comments on commit a97213d

Please sign in to comment.