Skip to content

Commit

Permalink
context: avoid corrupt file writes
Browse files Browse the repository at this point in the history
Write to a tempfile then move, so that if the
process dies mid-write it doesn't corrupt the store.

Also improve error messaging so that if a file does
get corrupted, the user has some hope of figuring
out which file is broken.

For background, see:
docker/for-win#13180
docker/for-win#12561

For a repro case, see:
https://github.com/nicks/contextstore-sandbox

Signed-off-by: Nick Santos <[email protected]>
  • Loading branch information
nicks committed Feb 21, 2023
1 parent dfb36ea commit 15db6d0
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
35 changes: 35 additions & 0 deletions cli/context/store/io_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package store
import (
"errors"
"io"
"os"
"path/filepath"
)

// LimitedReader is a fork of io.LimitedReader to override Read.
Expand All @@ -27,3 +29,36 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) {
l.N -= int64(n)
return n, err
}

// Write the bytes to a temporary file, then move the file to
// the destination path.
//
// This helps prevent corrupt files if the process is killed mid-write.
// Background:
// https://github.com/docker/for-win/issues/13180
// https://github.com/docker/for-win/issues/12561
func writeTempThenMove(dest string, bytes []byte, mode os.FileMode) error {
f, err := os.CreateTemp(filepath.Dir(dest), filepath.Base(dest))
if err != nil {
return err
}
name := f.Name()

_, err = f.Write(bytes)
if err != nil {
_ = f.Close()
_ = os.Remove(name)
return err
}
err = f.Close()
if err != nil {
_ = os.Remove(name)
return err
}
err = os.Chmod(name, mode)
if err != nil {
_ = os.Remove(name)
return err
}
return os.Rename(name, dest)
}
12 changes: 7 additions & 5 deletions cli/context/store/metadatastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package store

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -35,7 +36,7 @@ func (s *metadataStore) createOrUpdate(meta Metadata) error {
if err != nil {
return err
}
return os.WriteFile(filepath.Join(contextDir, metaFile), bytes, 0o644)
return writeTempThenMove(filepath.Join(contextDir, metaFile), bytes, 0o644)
}

func parseTypedOrMap(payload []byte, getter TypeGetter) (interface{}, error) {
Expand Down Expand Up @@ -65,7 +66,8 @@ func (s *metadataStore) get(name string) (Metadata, error) {
}

func (s *metadataStore) getByID(id contextdir) (Metadata, error) {
bytes, err := os.ReadFile(filepath.Join(s.contextDir(id), metaFile))
fileName := filepath.Join(s.contextDir(id), metaFile)
bytes, err := os.ReadFile(fileName)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return Metadata{}, errdefs.NotFound(errors.Wrap(err, "context not found"))
Expand All @@ -77,15 +79,15 @@ func (s *metadataStore) getByID(id contextdir) (Metadata, error) {
Endpoints: make(map[string]interface{}),
}
if err := json.Unmarshal(bytes, &untyped); err != nil {
return Metadata{}, err
return Metadata{}, fmt.Errorf("parsing %s: %v", fileName, err)
}
r.Name = untyped.Name
if r.Metadata, err = parseTypedOrMap(untyped.Metadata, s.config.contextType); err != nil {
return Metadata{}, err
return Metadata{}, fmt.Errorf("parsing %s: %v", fileName, err)
}
for k, v := range untyped.Endpoints {
if r.Endpoints[k], err = parseTypedOrMap(v, s.config.endpointTypes[k]); err != nil {
return Metadata{}, err
return Metadata{}, fmt.Errorf("parsing %s: %v", fileName, err)
}
}
return r, err
Expand Down
28 changes: 28 additions & 0 deletions cli/context/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"bytes"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"os"
"path"
"path/filepath"
"testing"

"github.com/docker/docker/errdefs"
Expand Down Expand Up @@ -230,3 +232,29 @@ func TestImportZipInvalid(t *testing.T) {
err = Import("zipInvalid", s, r)
assert.ErrorContains(t, err, "unexpected context file")
}


func TestCorruptMetadata(t *testing.T) {
tempDir := t.TempDir()
s := New(tempDir, testCfg)
err := s.CreateOrUpdate(
Metadata{
Endpoints: map[string]interface{}{
"ep1": endpoint{Foo: "bar"},
},
Metadata: context{Bar: "baz"},
Name: "source",
})
assert.NilError(t, err)

// Simulate the meta.json file getting corrupted
// by some external process.
contextDir := s.meta.contextDir(contextdirOf("source"))
contextFile := filepath.Join(contextDir, metaFile)
err = os.WriteFile(contextFile, nil, 0o600)
assert.NilError(t, err)

// Assert that the error message gives the user some clue where to look.
_, err = s.GetMetadata("source")
assert.ErrorContains(t, err, fmt.Sprintf("parsing %s: unexpected end of JSON input", contextFile))
}
2 changes: 1 addition & 1 deletion cli/context/store/tlsstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (s *tlsStore) createOrUpdate(name, endpointName, filename string, data []by
if err := os.MkdirAll(endpointDir, 0o700); err != nil {
return err
}
return os.WriteFile(filepath.Join(endpointDir, filename), data, 0o600)
return writeTempThenMove(filepath.Join(endpointDir, filename), data, 0o600)
}

func (s *tlsStore) getData(name, endpointName, filename string) ([]byte, error) {
Expand Down

0 comments on commit 15db6d0

Please sign in to comment.